//===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the MachineLegalizer class for
/// AMDGPU.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULegalizerInfo.h"
#include "AMDGPUTargetMachine.h"
#include "SIMachineFunctionInfo.h"

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace LegalizeActions;
using namespace LegalizeMutations;
using namespace LegalityPredicates;


static LegalityPredicate isMultiple32(unsigned TypeIdx,
                                      unsigned MaxSize = 512) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getScalarType();
    return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
  };
}

static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return Ty.isVector() &&
           Ty.getNumElements() % 2 != 0 &&
           Ty.getElementType().getSizeInBits() < 32;
  };
}

static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
  };
}

static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    unsigned Size = Ty.getSizeInBits();
    unsigned Pieces = (Size + 63) / 64;
    unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
  };
}

static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
  };
}

static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
  };
}

AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);
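
  // Vector types with 32-bit and 64-bit elements covering every total size up
  // to 512 bits; these feed the AllS32Vectors / AllS64Vectors lists below.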
  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  const LLT CodePtr = FlatPtr;

  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };

  setAction({G_BRCOND, S1}, Legal);

  // TODO: All multiples of 32, vectors of pointers, all v2s16 pairs, more
  // elements for v3s16
  getActionDefinitionsBuilder(G_PHI)
    .legalFor({S32, S64, V2S16, V4S16, S1, S128, S256})
    .legalFor(AllS32Vectors)
    .legalFor(AllS64Vectors)
    .legalFor(AddrSpaces64)
    .legalFor(AddrSpaces32)
    .clampScalar(0, S32, S256)
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .legalIf(isPointer(0));


  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases only legal
  // on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .widenScalarToNextPow2(0)
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));
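
  // FP immediates: an s16 G_FCONSTANT is only legal when the subtarget has
  // 16-bit instructions; otherwise constants are widened to at least s32.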
  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16);


  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal. We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr })
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  auto &FPOpActions = getActionDefinitionsBuilder(
    { G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
    // Use actual fsub instruction
    .legalFor({S32})
    // Must use fadd + fneg
    .lowerFor({S64, S16, V2S16})
    .scalarize(0)
    .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S64, LLT::scalar(33)},
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND})
    .legalFor({S32, S64})
    .scalarize(0);


  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));
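
  // FP compares produce an s1 result; only 32-bit and 64-bit operands are
  // legal, so smaller FP types are widened first.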
  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 need to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .widenScalarToNextPow2(0, 32)
    .widenScalarToNextPow2(1, 32);

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);


  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }
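
  // Load/store rules: register types wider than 32 bits whose memory size is
  // smaller are narrowed to s32, 96-bit vector accesses are split into 64-bit
  // pieces on targets without x3 load/store variants (pre-CI), and otherwise
  // power-of-two memory sizes up to 128 bits (plus 96 bits on CI+) are legal.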
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               ST.getGeneration() < AMDGPUSubtarget::SEA_ISLANDS;
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          // XXX hasLoadX3
          return (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS);

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);


  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
                               GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
                               LLT::vector(2, LocalPtr),
                               LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .widenScalarToNextPow2(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);
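
  // The vector is type index 1 for G_EXTRACT_VECTOR_ELT (the element is the
  // result) but type index 0 for G_INSERT_VECTOR_ELT, so the type indices are
  // selected per opcode below; the index type is always index 2.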
  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
                 VecTy.getSizeInBits() <= 512 &&
                 IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });

  for (unsigned Op : {G_EXTRACT, G_INSERT}) {
    unsigned BigTyIdx = Op == G_EXTRACT ? 1 : 0;
    unsigned LitTyIdx = Op == G_EXTRACT ? 0 : 1;

    // FIXME: Doesn't handle extract of illegal sizes.
    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          const LLT LitTy = Query.Types[LitTyIdx];
          return (BigTy.getSizeInBits() % 32 == 0) &&
                 (LitTy.getSizeInBits() % 16 == 0);
        })
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          return (BigTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(BigTyIdx, 16))
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT LitTy = Query.Types[LitTyIdx];
          return (LitTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(LitTyIdx, 16))
      .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
      .widenScalarToNextPow2(BigTyIdx, 32);

  }

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
    .legalForCartesianProduct(AllS32Vectors, {S32})
    .legalForCartesianProduct(AllS64Vectors, {S64})
    .clampNumElements(0, V16S32, V16S32)
    .clampNumElements(0, V2S64, V8S64)
    .minScalarSameAs(1, 0)
    // FIXME: Sort of a hack to make progress on other legalizations.
    .legalIf([=](const LegalityQuery &Query) {
      return Query.Types[0].getScalarSizeInBits() <= 32 ||
             Query.Types[0].getScalarSizeInBits() == 64;
    });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
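  // For G_MERGE_VALUES the wide value is type index 0 and the pieces are
  // index 1; G_UNMERGE_VALUES swaps them, hence the per-opcode
  // BigTyIdx/LitTyIdx selection below.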
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s8-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128.
          // Whichever is smaller.
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  computeTables();
  verify(*ST.getInstrInfo());
}

bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
                                         MachineRegisterInfo &MRI,
                                         MachineIRBuilder &MIRBuilder,
                                         GISelChangeObserver &Observer) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_ADDRSPACE_CAST:
    return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
  default:
    return false;
  }

  llvm_unreachable("expected switch to return");
}
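
// Return a 32-bit register holding the aperture base (the high half of the
// 64-bit flat address range) for the given LDS or private address space, read
// from the aperture hardware registers when available and otherwise loaded
// from the queue pointer.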
unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    auto ShiftAmt = MIRBuilder.buildConstant(S32, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt.getReg(0));

    return ApertureReg;
  }

  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
    Type::getInt8Ty(MF.getFunction().getContext()),
    AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    MinAlign(64, StructOffset));

  unsigned LoadResult = MRI.createGenericVirtualRegister(S32);
  unsigned LoadAddr = AMDGPU::NoRegister;

  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}
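
// Custom lowering for G_ADDRSPACE_CAST: flat -> local/private keeps the low
// 32 bits of the pointer, and local/private -> flat merges the 32-bit segment
// offset with the aperture base; both select against the source address
// space's null value so that null casts to null.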
bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    auto SegmentNull = MIRBuilder.buildConstant(DstTy, NullVal);
    auto FlatNull = MIRBuilder.buildConstant(SrcTy, 0);

    unsigned PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

    unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNull.getReg(0));
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNull.getReg(0));

    MI.eraseFromParent();
    return true;
  }

  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  auto SegmentNull =
      MIRBuilder.buildConstant(SrcTy, TM.getNullPointerValue(SrcAS));
  auto FlatNull =
      MIRBuilder.buildConstant(DstTy, TM.getNullPointerValue(DestAS));

  unsigned ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

  unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNull.getReg(0));

  unsigned BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  unsigned SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNull.getReg(0));

  MI.eraseFromParent();
  return true;
}