//===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the MachineLegalizer class for
/// AMDGPU.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULegalizerInfo.h"
#include "AMDGPUTargetMachine.h"
#include "SIMachineFunctionInfo.h"

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace LegalizeActions;
using namespace LegalizeMutations;
using namespace LegalityPredicates;


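// Matches types whose total size is at most MaxSize bits and whose scalar or
// element size is a multiple of 32 bits.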
static LegalityPredicate isMultiple32(unsigned TypeIdx,
                                      unsigned MaxSize = 512) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getScalarType();
    return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
  };
}

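// Matches vectors with an odd number of sub-32-bit elements (e.g. <3 x s16>),
// typically paired with oneMoreElement below to pad them out.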
static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return Ty.isVector() &&
           Ty.getNumElements() % 2 != 0 &&
           Ty.getElementType().getSizeInBits() < 32;
  };
}

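// Mutation that adds one more element to the vector, e.g. <3 x s16> becomes
// <4 x s16>.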
static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
  };
}

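// Mutation that splits a wide vector into pieces of at most 64 bits each,
// e.g. <4 x s32> becomes <2 x s32>.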
static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    unsigned Size = Ty.getSizeInBits();
    unsigned Pieces = (Size + 63) / 64;
    unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
  };
}

static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
  };
}

static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
  };
}

AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);

  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  const LLT CodePtr = FlatPtr;

  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };

  setAction({G_BRCOND, S1}, Legal);

  // TODO: All multiples of 32, vectors of pointers, all v2s16 pairs, more
  // elements for v3s16
  getActionDefinitionsBuilder(G_PHI)
    .legalFor({S32, S64, V2S16, V4S16, S1, S128, S256})
    .legalFor(AllS32Vectors)
    .legalFor(AllS64Vectors)
    .legalFor(AddrSpaces64)
    .legalFor(AddrSpaces32)
    .clampScalar(0, S32, S256)
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .legalIf(isPointer(0));


  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases that are
  // only legal on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .widenScalarToNextPow2(0)
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16);


  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal.  We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr })
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  auto &FPOpActions = getActionDefinitionsBuilder(
    { G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
    .legalForCartesianProduct({S16, S32, S64}, {S16, S32, S64})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
      // Use actual fsub instruction
      .legalFor({S32})
      // Must use fadd + fneg
      .lowerFor({S64, S16, V2S16})
      .scalarize(0)
      .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S64, LLT::scalar(33)},
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .lowerFor({{S32, S64}})
    .customFor({{S64, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_INTRINSIC_ROUND)
    .legalFor({S32, S64})
    .scalarize(0);

  if (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS) {
    getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_FCEIL, G_FRINT})
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64)
      .scalarize(0);
  } else {
    getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_FCEIL, G_FRINT})
      .legalFor({S32})
      .customFor({S64})
      .clampScalar(0, S32, S64)
      .scalarize(0);
  }

  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));

  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 need to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .widenScalarToNextPow2(0, 32)
    .widenScalarToNextPow2(1, 32);

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
        .legalFor({S32, S16, V2S16})
        .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
        .clampMaxNumElements(0, S16, 2)
        .clampScalar(0, S16, S32)
        .widenScalarToNextPow2(0)
        .scalarize(0);
    } else {
      getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
        .legalFor({S32, S16})
        .widenScalarToNextPow2(0)
        .clampScalar(0, S16, S32)
        .scalarize(0);
    }
  } else {
    getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
      .legalFor({S32})
      .clampScalar(0, S32, S32)
      .widenScalarToNextPow2(0)
      .scalarize(0);
  }

  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

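  // Pointer/integer casts are legal whenever the integer width matches the
  // pointer width; otherwise the integer side is widened or truncated to the
  // pointer size.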
  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }

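  // Loads and stores. Accesses whose register type is wider than both 32 bits
  // and the memory size are narrowed to 32 bits, 96-bit vector accesses are
  // split when dwordx3 operations are unavailable, and the remaining cases are
  // matched against the memory sizes handled below.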
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               !ST.hasDwordx3LoadStores();
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          return ST.hasDwordx3LoadStores();

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);


  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
          GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
          LLT::vector(2, LocalPtr), LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .widenScalarToNextPow2(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);

  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
            VecTy.getSizeInBits() <= 512 &&
            IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });

  for (unsigned Op : {G_EXTRACT, G_INSERT}) {
    unsigned BigTyIdx = Op == G_EXTRACT ? 1 : 0;
    unsigned LitTyIdx = Op == G_EXTRACT ? 0 : 1;

    // FIXME: Doesn't handle extract of illegal sizes.
    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          const LLT LitTy = Query.Types[LitTyIdx];
          return (BigTy.getSizeInBits() % 32 == 0) &&
                 (LitTy.getSizeInBits() % 16 == 0);
        })
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          return (BigTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(BigTyIdx, 16))
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT LitTy = Query.Types[LitTyIdx];
          return (LitTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(LitTyIdx, 16))
      .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
      .widenScalarToNextPow2(BigTyIdx, 32);

  }

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
      .legalForCartesianProduct(AllS32Vectors, {S32})
      .legalForCartesianProduct(AllS64Vectors, {S64})
      .clampNumElements(0, V16S32, V16S32)
      .clampNumElements(0, V2S64, V8S64)
      .minScalarSameAs(1, 0)
      // FIXME: Sort of a hack to make progress on other legalizations.
      .legalIf([=](const LegalityQuery &Query) {
        return Query.Types[0].getScalarSizeInBits() <= 32 ||
               Query.Types[0].getScalarSizeInBits() == 64;
      });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s16-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128,
          // whichever is smaller.
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  computeTables();
  verify(*ST.getInstrInfo());
}

bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
                                         MachineRegisterInfo &MRI,
                                         MachineIRBuilder &MIRBuilder,
                                         GISelChangeObserver &Observer) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_ADDRSPACE_CAST:
    return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
  case TargetOpcode::G_FRINT:
    return legalizeFrint(MI, MRI, MIRBuilder);
  case TargetOpcode::G_FCEIL:
    return legalizeFceil(MI, MRI, MIRBuilder);
  case TargetOpcode::G_INTRINSIC_TRUNC:
    return legalizeIntrinsicTrunc(MI, MRI, MIRBuilder);
  case TargetOpcode::G_SITOFP:
    return legalizeITOFP(MI, MRI, MIRBuilder, true);
  case TargetOpcode::G_UITOFP:
    return legalizeITOFP(MI, MRI, MIRBuilder, false);
  default:
    return false;
  }

  llvm_unreachable("expected switch to return");
}

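// Return a register holding the high 32 bits (the aperture) of the flat
// address range for the given segment address space. On subtargets with
// aperture registers this is read with s_getreg; otherwise it is loaded from
// the amd_queue_t structure through the queue pointer.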
unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    auto ShiftAmt = MIRBuilder.buildConstant(S32, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt.getReg(0));

    return ApertureReg;
  }

  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
                               Type::getInt8Ty(MF.getFunction().getContext()),
                               AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    MinAlign(64, StructOffset));

  Register LoadResult = MRI.createGenericVirtualRegister(S32);
  Register LoadAddr;

  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}

bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    auto SegmentNull = MIRBuilder.buildConstant(DstTy, NullVal);
    auto FlatNull = MIRBuilder.buildConstant(SrcTy, 0);

    Register PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

    Register CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNull.getReg(0));
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNull.getReg(0));

    MI.eraseFromParent();
    return true;
  }

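  // Casting from a segment address space to flat: build the 64-bit flat
  // pointer by placing the aperture above the 32-bit segment offset, and map
  // the segment null value to the flat null value.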
  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  auto SegmentNull =
      MIRBuilder.buildConstant(SrcTy, TM.getNullPointerValue(SrcAS));
  auto FlatNull =
      MIRBuilder.buildConstant(DstTy, TM.getNullPointerValue(DestAS));

  Register ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

  Register CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNull.getReg(0));

  Register BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  Register SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNull.getReg(0));

  MI.eraseFromParent();
  return true;
}

bool AMDGPULegalizerInfo::legalizeFrint(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MIRBuilder.setInstr(MI);

  Register Src = MI.getOperand(1).getReg();
  LLT Ty = MRI.getType(Src);
  assert(Ty.isScalar() && Ty.getSizeInBits() == 64);

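  // Round to integer by adding and then subtracting 2^52 (C1, with the sign of
  // the source), which rounds away the fractional bits. Sources whose
  // magnitude exceeds C2 (just below 2^52) are already integral, so the
  // original value is selected for those.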
  APFloat C1Val(APFloat::IEEEdouble(), "0x1.0p+52");
  APFloat C2Val(APFloat::IEEEdouble(), "0x1.fffffffffffffp+51");

  auto C1 = MIRBuilder.buildFConstant(Ty, C1Val);
  auto CopySign = MIRBuilder.buildFCopysign(Ty, C1, Src);

  // TODO: Should this propagate fast-math-flags?
  auto Tmp1 = MIRBuilder.buildFAdd(Ty, Src, CopySign);
  auto Tmp2 = MIRBuilder.buildFSub(Ty, Tmp1, CopySign);

  auto C2 = MIRBuilder.buildFConstant(Ty, C2Val);
  auto Fabs = MIRBuilder.buildFAbs(Ty, Src);

  auto Cond = MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, LLT::scalar(1), Fabs, C2);
  MIRBuilder.buildSelect(MI.getOperand(0).getReg(), Cond, Src, Tmp2);
  MI.eraseFromParent();
  return true;
}

bool AMDGPULegalizerInfo::legalizeFceil(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &B) const {
  B.setInstr(MI);

  const LLT S1 = LLT::scalar(1);
  const LLT S64 = LLT::scalar(64);

  Register Src = MI.getOperand(1).getReg();
  assert(MRI.getType(Src) == S64);

  // result = trunc(src)
  // if (src > 0.0 && src != result)
  //   result += 1.0

  auto Trunc = B.buildInstr(TargetOpcode::G_INTRINSIC_TRUNC, {S64}, {Src});

  const auto Zero = B.buildFConstant(S64, 0.0);
  const auto One = B.buildFConstant(S64, 1.0);
  auto Lt0 = B.buildFCmp(CmpInst::FCMP_OGT, S1, Src, Zero);
  auto NeTrunc = B.buildFCmp(CmpInst::FCMP_ONE, S1, Src, Trunc);
  auto And = B.buildAnd(S1, Lt0, NeTrunc);
  auto Add = B.buildSelect(S64, And, One, Zero);

  // TODO: Should this propagate fast-math-flags?
  B.buildFAdd(MI.getOperand(0).getReg(), Trunc, Add);
  MI.eraseFromParent();
  return true;
}

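// Extract the 11-bit exponent field from the high 32 bits of an f64 value and
// return it with the exponent bias (1023) subtracted.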
static MachineInstrBuilder extractF64Exponent(unsigned Hi,
                                              MachineIRBuilder &B) {
  const unsigned FractBits = 52;
  const unsigned ExpBits = 11;
  LLT S32 = LLT::scalar(32);

  auto Const0 = B.buildConstant(S32, FractBits - 32);
  auto Const1 = B.buildConstant(S32, ExpBits);

  auto ExpPart = B.buildIntrinsic(Intrinsic::amdgcn_ubfe, {S32}, false)
    .addUse(Hi)
    .addUse(Const0.getReg(0))
    .addUse(Const1.getReg(0));

  return B.buildSub(S32, ExpPart, B.buildConstant(S32, 1023));
}

bool AMDGPULegalizerInfo::legalizeIntrinsicTrunc(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &B) const {
  B.setInstr(MI);

  const LLT S1 = LLT::scalar(1);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);

  Register Src = MI.getOperand(1).getReg();
  assert(MRI.getType(Src) == S64);

  // TODO: Should this use extract since the low half is unused?
  auto Unmerge = B.buildUnmerge({S32, S32}, Src);
  Register Hi = Unmerge.getReg(1);

  // Extract the upper half, since this is where we will find the sign and
  // exponent.
  auto Exp = extractF64Exponent(Hi, B);

  const unsigned FractBits = 52;

  // Extract the sign bit.
  const auto SignBitMask = B.buildConstant(S32, UINT32_C(1) << 31);
  auto SignBit = B.buildAnd(S32, Hi, SignBitMask);

  const auto FractMask = B.buildConstant(S64, (UINT64_C(1) << FractBits) - 1);

  const auto Zero32 = B.buildConstant(S32, 0);

  // Extend back to 64-bits.
  auto SignBit64 = B.buildMerge(S64, {Zero32.getReg(0), SignBit.getReg(0)});

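  // Mask off the fraction bits that fall below the binary point for this
  // exponent. An exponent below zero leaves only the sign, and an exponent
  // above 51 means the value is already an integer; both cases are handled by
  // the selects below.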
  auto Shr = B.buildAShr(S64, FractMask, Exp);
  auto Not = B.buildNot(S64, Shr);
  auto Tmp0 = B.buildAnd(S64, Src, Not);
  auto FiftyOne = B.buildConstant(S32, FractBits - 1);

  auto ExpLt0 = B.buildICmp(CmpInst::ICMP_SLT, S1, Exp, Zero32);
  auto ExpGt51 = B.buildICmp(CmpInst::ICMP_SGT, S1, Exp, FiftyOne);

  auto Tmp1 = B.buildSelect(S64, ExpLt0, SignBit64, Tmp0);
  B.buildSelect(MI.getOperand(0).getReg(), ExpGt51, Src, Tmp1);
  MI.eraseFromParent();
  return true;
}

bool AMDGPULegalizerInfo::legalizeITOFP(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &B, bool Signed) const {
  B.setInstr(MI);

  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();

  const LLT S64 = LLT::scalar(64);
  const LLT S32 = LLT::scalar(32);

  assert(MRI.getType(Src) == S64 && MRI.getType(Dst) == S64);

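  // Expand the 64-bit conversion in terms of 32-bit conversions: convert the
  // high half (signed or unsigned as requested), scale it by 2^32 with ldexp,
  // and add in the unsigned conversion of the low half.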
  auto Unmerge = B.buildUnmerge({S32, S32}, Src);

  auto CvtHi = Signed ?
    B.buildSITOFP(S64, Unmerge.getReg(1)) :
    B.buildUITOFP(S64, Unmerge.getReg(1));

  auto CvtLo = B.buildUITOFP(S64, Unmerge.getReg(0));

  auto ThirtyTwo = B.buildConstant(S32, 32);
  auto LdExp = B.buildIntrinsic(Intrinsic::amdgcn_ldexp, {S64}, false)
    .addUse(CvtHi.getReg(0))
    .addUse(ThirtyTwo.getReg(0));

  // TODO: Should this propagate fast-math-flags?
  B.buildFAdd(Dst, LdExp, CvtLo);
  MI.eraseFromParent();
  return true;
}
