//===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the MachineLegalizer class for
/// AMDGPU.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULegalizerInfo.h"
#include "AMDGPUTargetMachine.h"
#include "SIMachineFunctionInfo.h"

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace LegalizeActions;
using namespace LegalizeMutations;
using namespace LegalityPredicates;

// True if the type's element size is a multiple of 32 bits and its total size
// does not exceed MaxSize bits.
static LegalityPredicate isMultiple32(unsigned TypeIdx,
                                      unsigned MaxSize = 512) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getScalarType();
    return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
  };
}

// True for vectors with an odd number of elements narrower than 32 bits.
static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return Ty.isVector() &&
           Ty.getNumElements() % 2 != 0 &&
           Ty.getElementType().getSizeInBits() < 32;
  };
}

// Grow the vector type at TypeIdx by one element.
static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
  };
}

// Shrink the vector type at TypeIdx to a piece that fits in 64 bits.
static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    unsigned Size = Ty.getSizeInBits();
    unsigned Pieces = (Size + 63) / 64;
    unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
  };
}

// True for vector types wider than Size bits.
static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
  };
}

// True for vector types with an odd number of elements.
static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
  };
}

AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);

  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  const LLT CodePtr = FlatPtr;

  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };

  setAction({G_BRCOND, S1}, Legal);

  // TODO: All multiples of 32, vectors of pointers, all v2s16 pairs, more
  // elements for v3s16
  getActionDefinitionsBuilder(G_PHI)
    .legalFor({S32, S64, V2S16, V4S16, S1, S128, S256})
    .legalFor(AllS32Vectors)
    .legalFor(AllS64Vectors)
    .legalFor(AddrSpaces64)
    .legalFor(AddrSpaces32)
    .clampScalar(0, S32, S256)
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .legalIf(isPointer(0));

  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases only legal
  // on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .widenScalarToNextPow2(0)
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));

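  // FP constants: s16 is only legal on subtargets with 16-bit instructions.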
  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16);

  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal. We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  auto &FPOpActions = getActionDefinitionsBuilder(
    {G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
    // Use actual fsub instruction
    .legalFor({S32})
    // Must use fadd + fneg
    .lowerFor({S64, S16, V2S16})
    .scalarize(0)
    .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S64, LLT::scalar(33)},
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND})
    .legalFor({S32, S64})
    .scalarize(0);

  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));

  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 need to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .widenScalarToNextPow2(0, 32)
    .widenScalarToNextPow2(1, 32);

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }

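  // Loads and stores. Extending accesses wider than 32 bits are narrowed to
  // s32, 96-bit vector accesses are split on subtargets without 96-bit memory
  // instructions, and 8/16-bit memory sizes are only legal with a 32-bit
  // register type.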
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               ST.getGeneration() < AMDGPUSubtarget::SEA_ISLANDS;
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          // XXX hasLoadX3
          return (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS);

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);

  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
                               GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
                               LLT::vector(2, LocalPtr),
                               LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);

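  // Vector element access is legal for vectors whose size is a multiple of 32
  // bits (up to 512 bits) with a 32-bit index type.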
  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
                 VecTy.getSizeInBits() <= 512 &&
                 IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });

  for (unsigned Op : {G_EXTRACT, G_INSERT}) {
    unsigned BigTyIdx = Op == G_EXTRACT ? 1 : 0;
    unsigned LitTyIdx = Op == G_EXTRACT ? 0 : 1;

    // FIXME: Doesn't handle extract of illegal sizes.
    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          const LLT LitTy = Query.Types[LitTyIdx];
          return (BigTy.getSizeInBits() % 32 == 0) &&
                 (LitTy.getSizeInBits() % 16 == 0);
        })
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          return (BigTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(BigTyIdx, 16))
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT LitTy = Query.Types[LitTyIdx];
          return (LitTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(LitTyIdx, 16))
      .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx));
  }

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
    .legalForCartesianProduct(AllS32Vectors, {S32})
    .legalForCartesianProduct(AllS64Vectors, {S64})
    .clampNumElements(0, V16S32, V16S32)
    .clampNumElements(0, V2S64, V8S64)
    .minScalarSameAs(1, 0)
    // FIXME: Sort of a hack to make progress on other legalizations.
    .legalIf([=](const LegalityQuery &Query) {
      return Query.Types[0].getScalarSizeInBits() <= 32 ||
             Query.Types[0].getScalarSizeInBits() == 64;
    });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s8-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128,
          // whichever is smaller.
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  computeTables();
  verify(*ST.getInstrInfo());
}

bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
                                         MachineRegisterInfo &MRI,
                                         MachineIRBuilder &MIRBuilder,
                                         GISelChangeObserver &Observer) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_ADDRSPACE_CAST:
    return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
  default:
    return false;
  }

  llvm_unreachable("expected switch to return");
}

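// Return a register holding the 32-bit aperture base (the high half of a
// 64-bit flat address) for the given LOCAL or PRIVATE address space. On
// subtargets with aperture registers this is read with s_getreg; otherwise it
// is loaded from the queue pointer.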
unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ShiftAmt = MRI.createGenericVirtualRegister(S32);
    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    MIRBuilder.buildConstant(ShiftAmt, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt);

    return ApertureReg;
  }

  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
    Type::getInt8Ty(MF.getFunction().getContext()),
    AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    MinAlign(64, StructOffset));

  unsigned LoadResult = MRI.createGenericVirtualRegister(S32);
  unsigned LoadAddr = AMDGPU::NoRegister;

  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}

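// Lower G_ADDRSPACE_CAST. No-op casts become G_BITCAST. Flat-to-segment casts
// select between the low 32 bits of the source pointer and the segment null
// value; segment-to-flat casts merge the 32-bit pointer with the aperture base
// and select between that and the flat null value.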
bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    unsigned SegmentNullReg = MRI.createGenericVirtualRegister(DstTy);
    unsigned FlatNullReg = MRI.createGenericVirtualRegister(SrcTy);

    MIRBuilder.buildConstant(SegmentNullReg, NullVal);
    MIRBuilder.buildConstant(FlatNullReg, 0);

    unsigned PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

    unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNullReg);
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNullReg);

    MI.eraseFromParent();
    return true;
  }

  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  unsigned FlatNullReg = MRI.createGenericVirtualRegister(DstTy);
  unsigned SegmentNullReg = MRI.createGenericVirtualRegister(SrcTy);
  MIRBuilder.buildConstant(SegmentNullReg, TM.getNullPointerValue(SrcAS));
  MIRBuilder.buildConstant(FlatNullReg, TM.getNullPointerValue(DestAS));

  unsigned ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

  unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNullReg);

  unsigned BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  unsigned SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNullReg);

  MI.eraseFromParent();
  return true;
}