1 //===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements the targeting of the Machinelegalizer class for
10 /// AMDGPU.
11 /// \todo This should be generated by TableGen.
12 //===----------------------------------------------------------------------===//
13 
14 #include "AMDGPU.h"
15 #include "AMDGPULegalizerInfo.h"
16 #include "AMDGPUTargetMachine.h"
17 #include "SIMachineFunctionInfo.h"
18 
19 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
20 #include "llvm/CodeGen/TargetOpcodes.h"
21 #include "llvm/CodeGen/ValueTypes.h"
22 #include "llvm/IR/DerivedTypes.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/Support/Debug.h"
25 
26 using namespace llvm;
27 using namespace LegalizeActions;
28 using namespace LegalizeMutations;
29 using namespace LegalityPredicates;
30 
31 
32 static LegalityPredicate isMultiple32(unsigned TypeIdx,
33                                       unsigned MaxSize = 512) {
34   return [=](const LegalityQuery &Query) {
35     const LLT Ty = Query.Types[TypeIdx];
36     const LLT EltTy = Ty.getScalarType();
37     return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
38   };
39 }
40 
41 static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
42   return [=](const LegalityQuery &Query) {
43     const LLT Ty = Query.Types[TypeIdx];
44     return Ty.isVector() &&
45            Ty.getNumElements() % 2 != 0 &&
46            Ty.getElementType().getSizeInBits() < 32;
47   };
48 }
49 
50 static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
51   return [=](const LegalityQuery &Query) {
52     const LLT Ty = Query.Types[TypeIdx];
53     const LLT EltTy = Ty.getElementType();
54     return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
55   };
56 }
57 
58 static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
59   return [=](const LegalityQuery &Query) {
60     const LLT Ty = Query.Types[TypeIdx];
61     const LLT EltTy = Ty.getElementType();
62     unsigned Size = Ty.getSizeInBits();
63     unsigned Pieces = (Size + 63) / 64;
64     unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
65     return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
66   };
67 }
68 
69 static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
70   return [=](const LegalityQuery &Query) {
71     const LLT QueryTy = Query.Types[TypeIdx];
72     return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
73   };
74 }
75 
76 static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
77   return [=](const LegalityQuery &Query) {
78     const LLT QueryTy = Query.Types[TypeIdx];
79     return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
80   };
81 }
82 
// Constructor: builds the per-opcode legalization rule tables for the AMDGPU
// (GCN) target, then freezes and verifies them. Rules within each builder
// chain are order-sensitive: earlier legalFor/If entries take precedence over
// later mutations.
AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

  // Build a pointer LLT for address space AS using the target's pointer size.
  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  // Commonly referenced scalar types.
  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  // Commonly referenced vector types.
  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);

  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  const LLT CodePtr = FlatPtr;

  // Groups of address spaces sharing a pointer width, for Cartesian-product
  // legality declarations below.
  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };

  setAction({G_BRCOND, S1}, Legal);

  // Integer arithmetic: only 32-bit scalars are natively legal here; wider or
  // vector forms are clamped and scalarized.
  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases only legal
  // on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));

  // S16 FP constants only exist with 16-bit instruction support.
  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32);


  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal.  We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr })
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  // Basic FP ops; extended with S16/V2S16 legality depending on subtarget
  // feature support below.
  auto &FPOpActions = getActionDefinitionsBuilder(
    { G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
      // Use actual fsub instruction
      .legalFor({S32})
      // Must use fadd + fneg
      .lowerFor({S64, S16, V2S16})
      .scalarize(0)
      .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND})
    .legalFor({S32, S64})
    .scalarize(0);


  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));

  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 needs to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64);
  // TODO: Scalarize

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);


  // Predicates comparing the bit sizes of two type indices in a query.
  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    // Otherwise resize the integer operand to match the pointer width.
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    // Otherwise resize the integer result to match the pointer width.
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  // Address space casts are custom-legalized (see legalizeAddrSpaceCast).
  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }

  getActionDefinitionsBuilder({G_LOAD, G_STORE})
    // Narrow a >32-bit register type loaded/stored through a smaller memory
    // size down to 32-bit pieces.
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
    // Pre-CI parts lack 96-bit (x3) vector loads; split to v2s32.
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               ST.getGeneration() < AMDGPUSubtarget::SEA_ISLANDS;
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          // XXX hasLoadX3
          return (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS);

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);


  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
          GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
          LLT::vector(2, LocalPtr), LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);

  // Vector element access: index type is always operand 2; the vector/element
  // operand positions swap between extract and insert.
  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
            VecTy.getSizeInBits() <= 512 &&
            IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  // Reject extracts whose result type doesn't match the vector element type.
  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });

  // FIXME: Doesn't handle extract of illegal sizes.
  getActionDefinitionsBuilder({G_EXTRACT, G_INSERT})
      .legalIf([=](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];
        const LLT &Ty1 = Query.Types[1];
        return (Ty0.getSizeInBits() % 16 == 0) &&
               (Ty1.getSizeInBits() % 16 == 0);
      })
      .moreElementsIf(isSmallOddVector(1), oneMoreElement(1))
      .widenScalarIf(
          [=](const LegalityQuery &Query) {
            const LLT Ty1 = Query.Types[1];
            return Ty1.isVector() && Ty1.getScalarSizeInBits() < 16;
          },
          LegalizeMutations::widenScalarOrEltToNextPow2(1, 16))
    .clampScalar(0, S16, S256);

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
      .legalForCartesianProduct(AllS32Vectors, {S32})
      .legalForCartesianProduct(AllS64Vectors, {S64})
      .clampNumElements(0, V16S32, V16S32)
      .clampNumElements(0, V2S64, V8S64)
      .minScalarSameAs(1, 0)
      // FIXME: Sort of a hack to make progress on other legalizations.
      .legalIf([=](const LegalityQuery &Query) {
        return Query.Types[0].getScalarSizeInBits() <= 32 ||
               Query.Types[0].getScalarSizeInBits() == 64;
      });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    // Vector elements outside [8, 64] bits or non-power-of-2 force a
    // scalarization of that operand.
    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s8-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128.
          // Whichever is smaller.
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  // Freeze the rule tables and verify their internal consistency.
  computeTables();
  verify(*ST.getInstrInfo());
}
633 
634 bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
635                                          MachineRegisterInfo &MRI,
636                                          MachineIRBuilder &MIRBuilder,
637                                          GISelChangeObserver &Observer) const {
638   switch (MI.getOpcode()) {
639   case TargetOpcode::G_ADDRSPACE_CAST:
640     return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
641   default:
642     return false;
643   }
644 
645   llvm_unreachable("expected switch to return");
646 }
647 
// Materialize a 32-bit virtual register holding the aperture (high half of a
// 64-bit flat address) for the given segment address space (LOCAL or PRIVATE
// — selected by AS below). Returns the register containing the value.
unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    // Select the hwreg field offset/width for the shared vs. private base.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
    // Pack id/offset/width into the S_GETREG_B32 immediate encoding.
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ShiftAmt = MRI.createGenericVirtualRegister(S32);
    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    // Shift the field value left by its width to place the aperture bits in
    // the high part of the 32-bit result.
    MIRBuilder.buildConstant(ShiftAmt, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt);

    return ApertureReg;
  }

  // No aperture registers: load the aperture from the queue descriptor in
  // constant address space instead.
  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
                               Type::getInt8Ty(MF.getFunction().getContext()),
                               AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    // NOTE(review): alignment derived from the struct offset relative to a
    // presumed 64-byte-aligned queue pointer — confirm against ABI docs.
    MinAlign(64, StructOffset));

  unsigned LoadResult = MRI.createGenericVirtualRegister(S32);
  unsigned LoadAddr = AMDGPU::NoRegister;

  // GEP to the aperture field, then load the 32-bit aperture value.
  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}
719 
// Custom legalization for G_ADDRSPACE_CAST. Handles three cases:
//  - no-op casts (same representation): rewritten to G_BITCAST in place;
//  - flat -> local/private: truncate to the low 32 bits, with a null check;
//  - local/private -> flat: merge the 32-bit pointer with the segment
//    aperture as the high half, with a null check.
// Returns true on success (the original instruction is replaced/erased).
bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  // No-op casts just reinterpret the bits; mutate in place to G_BITCAST.
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    // Null pointers have different encodings per address space; map flat null
    // to the destination segment's null value.
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    unsigned SegmentNullReg = MRI.createGenericVirtualRegister(DstTy);
    unsigned FlatNullReg = MRI.createGenericVirtualRegister(SrcTy);

    MIRBuilder.buildConstant(SegmentNullReg, NullVal);
    MIRBuilder.buildConstant(FlatNullReg, 0);

    unsigned PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

    // Select the truncated pointer, or the segment null if the flat source
    // was null.
    unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNullReg);
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNullReg);

    MI.eraseFromParent();
    return true;
  }

  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  // Segment -> flat: need both null encodings for the select below.
  unsigned FlatNullReg = MRI.createGenericVirtualRegister(DstTy);
  unsigned SegmentNullReg = MRI.createGenericVirtualRegister(SrcTy);
  MIRBuilder.buildConstant(SegmentNullReg, TM.getNullPointerValue(SrcAS));
  MIRBuilder.buildConstant(FlatNullReg, TM.getNullPointerValue(DestAS));

  // High 32 bits of the flat pointer come from the segment aperture.
  unsigned ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

  unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNullReg);

  unsigned BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  unsigned SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNullReg);

  MI.eraseFromParent();
  return true;
}
801