1 //===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// This file implements the targeting of the Machinelegalizer class for 10 /// AMDGPU. 11 /// \todo This should be generated by TableGen. 12 //===----------------------------------------------------------------------===// 13 14 #include "AMDGPU.h" 15 #include "AMDGPULegalizerInfo.h" 16 #include "AMDGPUTargetMachine.h" 17 #include "SIMachineFunctionInfo.h" 18 19 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 20 #include "llvm/CodeGen/TargetOpcodes.h" 21 #include "llvm/CodeGen/ValueTypes.h" 22 #include "llvm/IR/DerivedTypes.h" 23 #include "llvm/IR/Type.h" 24 #include "llvm/Support/Debug.h" 25 26 using namespace llvm; 27 using namespace LegalizeActions; 28 using namespace LegalizeMutations; 29 using namespace LegalityPredicates; 30 31 32 static LegalityPredicate isMultiple32(unsigned TypeIdx, 33 unsigned MaxSize = 512) { 34 return [=](const LegalityQuery &Query) { 35 const LLT Ty = Query.Types[TypeIdx]; 36 const LLT EltTy = Ty.getScalarType(); 37 return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0; 38 }; 39 } 40 41 static LegalityPredicate isSmallOddVector(unsigned TypeIdx) { 42 return [=](const LegalityQuery &Query) { 43 const LLT Ty = Query.Types[TypeIdx]; 44 return Ty.isVector() && 45 Ty.getNumElements() % 2 != 0 && 46 Ty.getElementType().getSizeInBits() < 32; 47 }; 48 } 49 50 static LegalizeMutation oneMoreElement(unsigned TypeIdx) { 51 return [=](const LegalityQuery &Query) { 52 const LLT Ty = Query.Types[TypeIdx]; 53 const LLT EltTy = Ty.getElementType(); 54 return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy)); 55 }; 56 } 

/// Mutation: split the vector at \p TypeIdx into pieces of at most 64 bits
/// each by reducing the element count, keeping the element type. May return a
/// scalar when only one element remains.
static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    unsigned Size = Ty.getSizeInBits();
    // Number of 64-bit pieces needed to hold the whole vector.
    unsigned Pieces = (Size + 63) / 64;
    unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
  };
}

/// Predicate: the type at \p TypeIdx is a vector wider than \p Size bits.
static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
  };
}

/// Predicate: the type at \p TypeIdx is a vector with an odd element count.
static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
  };
}

/// Build the legalization rule tables for the given subtarget. Note that the
/// order of the rules within each getActionDefinitionsBuilder() chain is
/// significant: earlier rules take precedence.
AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

  // Make an LLT pointer type for an address space, using the target's
  // pointer size for that address space.
  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  // Commonly used scalar types.
  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  // 16-bit element vectors.
  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);

  // 32-bit element vectors, up to the 512-bit register limit.
  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  // 64-bit element vectors, up to the 512-bit register limit.
  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  // Code pointers are flat pointers.
  const LLT CodePtr = FlatPtr;

  // Pointer types grouped by pointer width.
  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };

  setAction({G_BRCOND, S1}, Legal);

  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases only legal
  // on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));

  // s16 constants are only legal when the subtarget has 16-bit instructions.
  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32);


  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal.  We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr })
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  auto &FPOpActions = getActionDefinitionsBuilder(
    { G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
      // Use actual fsub instruction
    .legalFor({S32})
      // Must use fadd + fneg
    .lowerFor({S64, S16, V2S16})
    .scalarize(0)
    .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND})
    .legalFor({S32, S64})
    .scalarize(0);


  // The offset type of G_GEP must match the pointer width.
  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));

  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 needs to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64);
  // TODO: Scalarize

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);


  // Predicate: type at TypeIdx0 is strictly narrower than type at TypeIdx1.
  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  // Predicate: type at TypeIdx0 is strictly wider than type at TypeIdx1.
  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    // Otherwise resize the integer operand to match the pointer width.
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    // Otherwise resize the integer result to match the pointer width.
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  // Address space casts are handled in legalizeAddrSpaceCast().
  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }

  getActionDefinitionsBuilder({G_LOAD, G_STORE})
    // Break wide scalars with a narrower memory size into 32-bit pieces.
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
    // 96-bit vector accesses aren't available before SEA_ISLANDS; split to
    // 64-bit pieces.
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               ST.getGeneration() < AMDGPUSubtarget::SEA_ISLANDS;
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        // Extending vector loads/truncating vector stores aren't legal.
        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          // XXX hasLoadX3
          return (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS);

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);


  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
          GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
          LLT::vector(2, LocalPtr), LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);

  // Vector element access. The vector operand/result and element
  // operand/result are at different operand indices for extract vs. insert.
  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
                 VecTy.getSizeInBits() <= 512 &&
                 IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  // The extracted element type must match the vector's element type.
  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });

  // FIXME: Doesn't handle extract of illegal sizes.
  getActionDefinitionsBuilder({G_EXTRACT, G_INSERT})
    .legalIf([=](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];
        const LLT &Ty1 = Query.Types[1];
        return (Ty0.getSizeInBits() % 16 == 0) &&
               (Ty1.getSizeInBits() % 16 == 0);
      })
    .moreElementsIf(isSmallOddVector(1), oneMoreElement(1))
    // Widen sub-16-bit vector elements up to 16 bits.
    .widenScalarIf(
      [=](const LegalityQuery &Query) {
        const LLT Ty1 = Query.Types[1];
        return Ty1.isVector() && Ty1.getScalarSizeInBits() < 16;
      },
      LegalizeMutations::widenScalarOrEltToNextPow2(1, 16))
    .clampScalar(0, S16, S256);

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
    .legalForCartesianProduct(AllS32Vectors, {S32})
    .legalForCartesianProduct(AllS64Vectors, {S64})
    .clampNumElements(0, V16S32, V16S32)
    .clampNumElements(0, V2S64, V8S64)
    .minScalarSameAs(1, 0)
    // FIXME: Sort of a hack to make progress on other legalizations.
    .legalIf([=](const LegalityQuery &Query) {
      return Query.Types[0].getScalarSizeInBits() <= 32 ||
             Query.Types[0].getScalarSizeInBits() == 64;
    });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    // True if the type at TypeIdx is a vector whose element size is not a
    // power of 2 in the range [8, 64].
    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s8-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128.
          // Whichever is smaller.
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  computeTables();
  verify(*ST.getInstrInfo());
}

/// Dispatch custom legalization for instructions marked .custom() above.
/// Returns true if the instruction was successfully legalized.
bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
                                         MachineRegisterInfo &MRI,
                                         MachineIRBuilder &MIRBuilder,
                                         GISelChangeObserver &Observer) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_ADDRSPACE_CAST:
    return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
  default:
    return false;
  }

  llvm_unreachable("expected switch to return");
}

/// Return a virtual register (s32) holding the aperture value for the local
/// or private segment identified by \p AS. When the subtarget has aperture
/// registers, reads it with s_getreg; otherwise loads it from a fixed offset
/// relative to the queue pointer.
unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
    // Pack the hwreg id, offset, and width into the s_getreg immediate.
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ShiftAmt = MRI.createGenericVirtualRegister(S32);
    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    // Shift the field into the high bits to form the aperture value.
    MIRBuilder.buildConstant(ShiftAmt, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt);

    return ApertureReg;
  }

  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
    Type::getInt8Ty(MF.getFunction().getContext()),
    AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    MinAlign(64, StructOffset));

  unsigned LoadResult = MRI.createGenericVirtualRegister(S32);
  unsigned LoadAddr = AMDGPU::NoRegister;

  // Load the 32-bit aperture from queue_ptr + StructOffset.
  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}

/// Custom-legalize G_ADDRSPACE_CAST. Handles no-op casts as bitcasts,
/// flat -> local/private casts by truncating to the low 32 bits, and
/// local/private -> flat casts by merging the source with the segment
/// aperture; both directions map the null pointer of one space to the null
/// pointer of the other via a compare+select.
bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    // No-op cast: rewrite in place as a bitcast.
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    unsigned SegmentNullReg = MRI.createGenericVirtualRegister(DstTy);
    unsigned FlatNullReg = MRI.createGenericVirtualRegister(SrcTy);

    MIRBuilder.buildConstant(SegmentNullReg, NullVal);
    MIRBuilder.buildConstant(FlatNullReg, 0);

    unsigned PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

    // Map flat null to the destination segment's null value.
    unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNullReg);
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNullReg);

    MI.eraseFromParent();
    return true;
  }

  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  unsigned FlatNullReg = MRI.createGenericVirtualRegister(DstTy);
  unsigned SegmentNullReg = MRI.createGenericVirtualRegister(SrcTy);
  MIRBuilder.buildConstant(SegmentNullReg, TM.getNullPointerValue(SrcAS));
  MIRBuilder.buildConstant(FlatNullReg, TM.getNullPointerValue(DestAS));

  unsigned ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

  unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNullReg);

  unsigned BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  unsigned SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNullReg);

  MI.eraseFromParent();
  return true;
}