1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "clang/Support/RISCVVIntrinsicUtils.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/ADT/Optional.h" 12 #include "llvm/ADT/SmallSet.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringMap.h" 15 #include "llvm/ADT/StringSet.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/raw_ostream.h" 18 #include <numeric> 19 #include <set> 20 #include <unordered_map> 21 22 using namespace llvm; 23 24 namespace clang { 25 namespace RISCV { 26 27 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor( 28 BaseTypeModifier::Vector, VectorTypeModifier::MaskVector); 29 const PrototypeDescriptor PrototypeDescriptor::VL = 30 PrototypeDescriptor(BaseTypeModifier::SizeT); 31 const PrototypeDescriptor PrototypeDescriptor::Vector = 32 PrototypeDescriptor(BaseTypeModifier::Vector); 33 34 // Concat BasicType, LMUL and Proto as key 35 static std::unordered_map<uint64_t, RVVType> LegalTypes; 36 static std::set<uint64_t> IllegalTypes; 37 38 //===----------------------------------------------------------------------===// 39 // Type implementation 40 //===----------------------------------------------------------------------===// 41 42 LMULType::LMULType(int NewLog2LMUL) { 43 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 44 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); 45 Log2LMUL = NewLog2LMUL; 46 } 47 48 std::string LMULType::str() const { 49 if (Log2LMUL < 0) 50 return "mf" + utostr(1ULL << (-Log2LMUL)); 51 return "m" + utostr(1ULL << Log2LMUL); 52 } 53 54 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { 55 int Log2ScaleResult = 0; 56 switch (ElementBitwidth) { 57 default: 58 break; 59 case 8: 60 Log2ScaleResult = Log2LMUL + 3; 61 break; 62 case 16: 63 Log2ScaleResult = Log2LMUL + 2; 64 break; 65 case 32: 66 Log2ScaleResult = Log2LMUL + 1; 67 break; 68 case 64: 69 Log2ScaleResult = Log2LMUL; 70 break; 71 } 72 // Illegal vscale result would be less than 1 73 if (Log2ScaleResult < 0) 74 return llvm::None; 75 return 1 << Log2ScaleResult; 76 } 77 78 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } 79 80 RVVType::RVVType(BasicType BT, int Log2LMUL, 81 const PrototypeDescriptor &prototype) 82 : BT(BT), LMUL(LMULType(Log2LMUL)) { 83 applyBasicType(); 84 applyModifier(prototype); 85 Valid = verifyType(); 86 if (Valid) { 87 initBuiltinStr(); 88 initTypeStr(); 89 if (isVector()) { 90 initClangBuiltinStr(); 91 } 92 } 93 } 94 95 // clang-format off 96 // boolean type are encoded the ratio of n (SEW/LMUL) 97 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 98 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t 99 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 100 101 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 102 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 103 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 104 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 105 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 106 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 107 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 108 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 109 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 110 // clang-format on 111 112 bool RVVType::verifyType() const { 113 if (ScalarType == Invalid) 114 return false; 115 if (isScalar()) 116 return true; 117 if (!Scale.hasValue()) 118 return false; 119 if (isFloat() && ElementBitwidth == 8) 120 return false; 121 unsigned V = Scale.getValue(); 122 switch (ElementBitwidth) { 123 case 1: 124 case 8: 125 // Check Scale is 1,2,4,8,16,32,64 126 return (V <= 64 && isPowerOf2_32(V)); 127 case 16: 128 // Check Scale is 1,2,4,8,16,32 129 return (V <= 32 && isPowerOf2_32(V)); 130 case 32: 131 // Check Scale is 1,2,4,8,16 132 return (V <= 16 && isPowerOf2_32(V)); 133 case 64: 134 // Check Scale is 1,2,4,8 135 return (V <= 8 && isPowerOf2_32(V)); 136 } 137 return false; 138 } 139 140 void RVVType::initBuiltinStr() { 141 assert(isValid() && "RVVType is invalid"); 142 switch (ScalarType) { 143 case ScalarTypeKind::Void: 144 BuiltinStr = "v"; 145 return; 146 case ScalarTypeKind::Size_t: 147 BuiltinStr = "z"; 148 if (IsImmediate) 149 BuiltinStr = "I" + BuiltinStr; 150 if (IsPointer) 151 BuiltinStr += "*"; 152 return; 153 case ScalarTypeKind::Ptrdiff_t: 154 BuiltinStr = "Y"; 155 return; 156 case ScalarTypeKind::UnsignedLong: 157 BuiltinStr = "ULi"; 158 return; 159 case ScalarTypeKind::SignedLong: 160 BuiltinStr = "Li"; 161 return; 162 case ScalarTypeKind::Boolean: 163 assert(ElementBitwidth == 1); 164 BuiltinStr += "b"; 165 break; 166 case ScalarTypeKind::SignedInteger: 167 case ScalarTypeKind::UnsignedInteger: 168 switch (ElementBitwidth) { 169 case 8: 170 BuiltinStr += "c"; 171 break; 172 case 16: 173 BuiltinStr += "s"; 174 break; 175 case 32: 176 BuiltinStr += "i"; 177 break; 178 case 64: 179 BuiltinStr += "Wi"; 180 break; 181 default: 182 llvm_unreachable("Unhandled ElementBitwidth!"); 183 } 184 if (isSignedInteger()) 185 BuiltinStr = "S" + BuiltinStr; 186 else 187 BuiltinStr = "U" + BuiltinStr; 188 break; 189 case ScalarTypeKind::Float: 190 switch (ElementBitwidth) { 191 case 16: 192 BuiltinStr += "x"; 193 break; 194 case 32: 195 BuiltinStr += "f"; 196 break; 197 case 64: 198 BuiltinStr += "d"; 199 break; 200 default: 201 llvm_unreachable("Unhandled ElementBitwidth!"); 202 } 203 break; 204 default: 205 llvm_unreachable("ScalarType is invalid!"); 206 } 207 if (IsImmediate) 208 BuiltinStr = "I" + BuiltinStr; 209 if (isScalar()) { 210 if (IsConstant) 211 BuiltinStr += "C"; 212 if (IsPointer) 213 BuiltinStr += "*"; 214 return; 215 } 216 BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; 217 // Pointer to vector types. Defined for segment load intrinsics. 218 // segment load intrinsics have pointer type arguments to store the loaded 219 // vector values. 220 if (IsPointer) 221 BuiltinStr += "*"; 222 } 223 224 void RVVType::initClangBuiltinStr() { 225 assert(isValid() && "RVVType is invalid"); 226 assert(isVector() && "Handle Vector type only"); 227 228 ClangBuiltinStr = "__rvv_"; 229 switch (ScalarType) { 230 case ScalarTypeKind::Boolean: 231 ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; 232 return; 233 case ScalarTypeKind::Float: 234 ClangBuiltinStr += "float"; 235 break; 236 case ScalarTypeKind::SignedInteger: 237 ClangBuiltinStr += "int"; 238 break; 239 case ScalarTypeKind::UnsignedInteger: 240 ClangBuiltinStr += "uint"; 241 break; 242 default: 243 llvm_unreachable("ScalarTypeKind is invalid"); 244 } 245 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; 246 } 247 248 void RVVType::initTypeStr() { 249 assert(isValid() && "RVVType is invalid"); 250 251 if (IsConstant) 252 Str += "const "; 253 254 auto getTypeString = [&](StringRef TypeStr) { 255 if (isScalar()) 256 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 257 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") 258 .str(); 259 }; 260 261 switch (ScalarType) { 262 case ScalarTypeKind::Void: 263 Str = "void"; 264 return; 265 case ScalarTypeKind::Size_t: 266 Str = "size_t"; 267 if (IsPointer) 268 Str += " *"; 269 return; 270 case ScalarTypeKind::Ptrdiff_t: 271 Str = "ptrdiff_t"; 272 return; 273 case ScalarTypeKind::UnsignedLong: 274 Str = "unsigned long"; 275 return; 276 case ScalarTypeKind::SignedLong: 277 Str = "long"; 278 return; 279 case ScalarTypeKind::Boolean: 280 if (isScalar()) 281 Str += "bool"; 282 else 283 // Vector bool is special case, the formulate is 284 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 285 Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; 286 break; 287 case ScalarTypeKind::Float: 288 if (isScalar()) { 289 if (ElementBitwidth == 64) 290 Str += "double"; 291 else if (ElementBitwidth == 32) 292 Str += "float"; 293 else if (ElementBitwidth == 16) 294 Str += "_Float16"; 295 else 296 llvm_unreachable("Unhandled floating type."); 297 } else 298 Str += getTypeString("float"); 299 break; 300 case ScalarTypeKind::SignedInteger: 301 Str += getTypeString("int"); 302 break; 303 case ScalarTypeKind::UnsignedInteger: 304 Str += getTypeString("uint"); 305 break; 306 default: 307 llvm_unreachable("ScalarType is invalid!"); 308 } 309 if (IsPointer) 310 Str += " *"; 311 } 312 313 void RVVType::initShortStr() { 314 switch (ScalarType) { 315 case ScalarTypeKind::Boolean: 316 assert(isVector()); 317 ShortStr = "b" + utostr(64 / Scale.getValue()); 318 return; 319 case ScalarTypeKind::Float: 320 ShortStr = "f" + utostr(ElementBitwidth); 321 break; 322 case ScalarTypeKind::SignedInteger: 323 ShortStr = "i" + utostr(ElementBitwidth); 324 break; 325 case ScalarTypeKind::UnsignedInteger: 326 ShortStr = "u" + utostr(ElementBitwidth); 327 break; 328 default: 329 llvm_unreachable("Unhandled case!"); 330 } 331 if (isVector()) 332 ShortStr += LMUL.str(); 333 } 334 335 void RVVType::applyBasicType() { 336 switch (BT) { 337 case BasicType::Int8: 338 ElementBitwidth = 8; 339 ScalarType = ScalarTypeKind::SignedInteger; 340 break; 341 case BasicType::Int16: 342 ElementBitwidth = 16; 343 ScalarType = ScalarTypeKind::SignedInteger; 344 break; 345 case BasicType::Int32: 346 ElementBitwidth = 32; 347 ScalarType = ScalarTypeKind::SignedInteger; 348 break; 349 case BasicType::Int64: 350 ElementBitwidth = 64; 351 ScalarType = ScalarTypeKind::SignedInteger; 352 break; 353 case BasicType::Float16: 354 ElementBitwidth = 16; 355 ScalarType = ScalarTypeKind::Float; 356 break; 357 case BasicType::Float32: 358 ElementBitwidth = 32; 359 ScalarType = ScalarTypeKind::Float; 360 break; 361 case BasicType::Float64: 362 ElementBitwidth = 64; 363 ScalarType = ScalarTypeKind::Float; 364 break; 365 default: 366 llvm_unreachable("Unhandled type code!"); 367 } 368 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 369 } 370 371 Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor( 372 llvm::StringRef PrototypeDescriptorStr) { 373 PrototypeDescriptor PD; 374 BaseTypeModifier PT = BaseTypeModifier::Invalid; 375 VectorTypeModifier VTM = VectorTypeModifier::NoModifier; 376 377 if (PrototypeDescriptorStr.empty()) 378 return PD; 379 380 // Handle base type modifier 381 auto PType = PrototypeDescriptorStr.back(); 382 switch (PType) { 383 case 'e': 384 PT = BaseTypeModifier::Scalar; 385 break; 386 case 'v': 387 PT = BaseTypeModifier::Vector; 388 break; 389 case 'w': 390 PT = BaseTypeModifier::Vector; 391 VTM = VectorTypeModifier::Widening2XVector; 392 break; 393 case 'q': 394 PT = BaseTypeModifier::Vector; 395 VTM = VectorTypeModifier::Widening4XVector; 396 break; 397 case 'o': 398 PT = BaseTypeModifier::Vector; 399 VTM = VectorTypeModifier::Widening8XVector; 400 break; 401 case 'm': 402 PT = BaseTypeModifier::Vector; 403 VTM = VectorTypeModifier::MaskVector; 404 break; 405 case '0': 406 PT = BaseTypeModifier::Void; 407 break; 408 case 'z': 409 PT = BaseTypeModifier::SizeT; 410 break; 411 case 't': 412 PT = BaseTypeModifier::Ptrdiff; 413 break; 414 case 'u': 415 PT = BaseTypeModifier::UnsignedLong; 416 break; 417 case 'l': 418 PT = BaseTypeModifier::SignedLong; 419 break; 420 default: 421 llvm_unreachable("Illegal primitive type transformers!"); 422 } 423 PD.PT = static_cast<uint8_t>(PT); 424 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back(); 425 426 // Compute the vector type transformers, it can only appear one time. 427 if (PrototypeDescriptorStr.startswith("(")) { 428 assert(VTM == VectorTypeModifier::NoModifier && 429 "VectorTypeModifier should only have one modifier"); 430 size_t Idx = PrototypeDescriptorStr.find(')'); 431 assert(Idx != StringRef::npos); 432 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx); 433 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1); 434 assert(!PrototypeDescriptorStr.contains('(') && 435 "Only allow one vector type modifier"); 436 437 auto ComplexTT = ComplexType.split(":"); 438 if (ComplexTT.first == "Log2EEW") { 439 uint32_t Log2EEW; 440 if (ComplexTT.second.getAsInteger(10, Log2EEW)) { 441 llvm_unreachable("Invalid Log2EEW value!"); 442 return None; 443 } 444 switch (Log2EEW) { 445 case 3: 446 VTM = VectorTypeModifier::Log2EEW3; 447 break; 448 case 4: 449 VTM = VectorTypeModifier::Log2EEW4; 450 break; 451 case 5: 452 VTM = VectorTypeModifier::Log2EEW5; 453 break; 454 case 6: 455 VTM = VectorTypeModifier::Log2EEW6; 456 break; 457 default: 458 llvm_unreachable("Invalid Log2EEW value, should be [3-6]"); 459 return None; 460 } 461 } else if (ComplexTT.first == "FixedSEW") { 462 uint32_t NewSEW; 463 if (ComplexTT.second.getAsInteger(10, NewSEW)) { 464 llvm_unreachable("Invalid FixedSEW value!"); 465 return None; 466 } 467 switch (NewSEW) { 468 case 8: 469 VTM = VectorTypeModifier::FixedSEW8; 470 break; 471 case 16: 472 VTM = VectorTypeModifier::FixedSEW16; 473 break; 474 case 32: 475 VTM = VectorTypeModifier::FixedSEW32; 476 break; 477 case 64: 478 VTM = VectorTypeModifier::FixedSEW64; 479 break; 480 default: 481 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64"); 482 return None; 483 } 484 } else if (ComplexTT.first == "LFixedLog2LMUL") { 485 int32_t Log2LMUL; 486 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 487 llvm_unreachable("Invalid LFixedLog2LMUL value!"); 488 return None; 489 } 490 switch (Log2LMUL) { 491 case -3: 492 VTM = VectorTypeModifier::LFixedLog2LMULN3; 493 break; 494 case -2: 495 VTM = VectorTypeModifier::LFixedLog2LMULN2; 496 break; 497 case -1: 498 VTM = VectorTypeModifier::LFixedLog2LMULN1; 499 break; 500 case 0: 501 VTM = VectorTypeModifier::LFixedLog2LMUL0; 502 break; 503 case 1: 504 VTM = VectorTypeModifier::LFixedLog2LMUL1; 505 break; 506 case 2: 507 VTM = VectorTypeModifier::LFixedLog2LMUL2; 508 break; 509 case 3: 510 VTM = VectorTypeModifier::LFixedLog2LMUL3; 511 break; 512 default: 513 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 514 return None; 515 } 516 } else if (ComplexTT.first == "SFixedLog2LMUL") { 517 int32_t Log2LMUL; 518 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) { 519 llvm_unreachable("Invalid SFixedLog2LMUL value!"); 520 return None; 521 } 522 switch (Log2LMUL) { 523 case -3: 524 VTM = VectorTypeModifier::SFixedLog2LMULN3; 525 break; 526 case -2: 527 VTM = VectorTypeModifier::SFixedLog2LMULN2; 528 break; 529 case -1: 530 VTM = VectorTypeModifier::SFixedLog2LMULN1; 531 break; 532 case 0: 533 VTM = VectorTypeModifier::SFixedLog2LMUL0; 534 break; 535 case 1: 536 VTM = VectorTypeModifier::SFixedLog2LMUL1; 537 break; 538 case 2: 539 VTM = VectorTypeModifier::SFixedLog2LMUL2; 540 break; 541 case 3: 542 VTM = VectorTypeModifier::SFixedLog2LMUL3; 543 break; 544 default: 545 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]"); 546 return None; 547 } 548 549 } else { 550 llvm_unreachable("Illegal complex type transformers!"); 551 } 552 } 553 PD.VTM = static_cast<uint8_t>(VTM); 554 555 // Compute the remain type transformers 556 TypeModifier TM = TypeModifier::NoModifier; 557 for (char I : PrototypeDescriptorStr) { 558 switch (I) { 559 case 'P': 560 if ((TM & TypeModifier::Const) == TypeModifier::Const) 561 llvm_unreachable("'P' transformer cannot be used after 'C'"); 562 if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer) 563 llvm_unreachable("'P' transformer cannot be used twice"); 564 TM |= TypeModifier::Pointer; 565 break; 566 case 'C': 567 TM |= TypeModifier::Const; 568 break; 569 case 'K': 570 TM |= TypeModifier::Immediate; 571 break; 572 case 'U': 573 TM |= TypeModifier::UnsignedInteger; 574 break; 575 case 'I': 576 TM |= TypeModifier::SignedInteger; 577 break; 578 case 'F': 579 TM |= TypeModifier::Float; 580 break; 581 case 'S': 582 TM |= TypeModifier::LMUL1; 583 break; 584 default: 585 llvm_unreachable("Illegal non-primitive type transformer!"); 586 } 587 } 588 PD.TM = static_cast<uint8_t>(TM); 589 590 return PD; 591 } 592 593 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) { 594 // Handle primitive type transformer 595 switch (static_cast<BaseTypeModifier>(Transformer.PT)) { 596 case BaseTypeModifier::Scalar: 597 Scale = 0; 598 break; 599 case BaseTypeModifier::Vector: 600 Scale = LMUL.getScale(ElementBitwidth); 601 break; 602 case BaseTypeModifier::Void: 603 ScalarType = ScalarTypeKind::Void; 604 break; 605 case BaseTypeModifier::SizeT: 606 ScalarType = ScalarTypeKind::Size_t; 607 break; 608 case BaseTypeModifier::Ptrdiff: 609 ScalarType = ScalarTypeKind::Ptrdiff_t; 610 break; 611 case BaseTypeModifier::UnsignedLong: 612 ScalarType = ScalarTypeKind::UnsignedLong; 613 break; 614 case BaseTypeModifier::SignedLong: 615 ScalarType = ScalarTypeKind::SignedLong; 616 break; 617 case BaseTypeModifier::Invalid: 618 ScalarType = ScalarTypeKind::Invalid; 619 return; 620 } 621 622 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) { 623 case VectorTypeModifier::Widening2XVector: 624 ElementBitwidth *= 2; 625 LMUL.MulLog2LMUL(1); 626 Scale = LMUL.getScale(ElementBitwidth); 627 break; 628 case VectorTypeModifier::Widening4XVector: 629 ElementBitwidth *= 4; 630 LMUL.MulLog2LMUL(2); 631 Scale = LMUL.getScale(ElementBitwidth); 632 break; 633 case VectorTypeModifier::Widening8XVector: 634 ElementBitwidth *= 8; 635 LMUL.MulLog2LMUL(3); 636 Scale = LMUL.getScale(ElementBitwidth); 637 break; 638 case VectorTypeModifier::MaskVector: 639 ScalarType = ScalarTypeKind::Boolean; 640 Scale = LMUL.getScale(ElementBitwidth); 641 ElementBitwidth = 1; 642 break; 643 case VectorTypeModifier::Log2EEW3: 644 applyLog2EEW(3); 645 break; 646 case VectorTypeModifier::Log2EEW4: 647 applyLog2EEW(4); 648 break; 649 case VectorTypeModifier::Log2EEW5: 650 applyLog2EEW(5); 651 break; 652 case VectorTypeModifier::Log2EEW6: 653 applyLog2EEW(6); 654 break; 655 case VectorTypeModifier::FixedSEW8: 656 applyFixedSEW(8); 657 break; 658 case VectorTypeModifier::FixedSEW16: 659 applyFixedSEW(16); 660 break; 661 case VectorTypeModifier::FixedSEW32: 662 applyFixedSEW(32); 663 break; 664 case VectorTypeModifier::FixedSEW64: 665 applyFixedSEW(64); 666 break; 667 case VectorTypeModifier::LFixedLog2LMULN3: 668 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan); 669 break; 670 case VectorTypeModifier::LFixedLog2LMULN2: 671 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan); 672 break; 673 case VectorTypeModifier::LFixedLog2LMULN1: 674 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan); 675 break; 676 case VectorTypeModifier::LFixedLog2LMUL0: 677 applyFixedLog2LMUL(0, FixedLMULType::LargerThan); 678 break; 679 case VectorTypeModifier::LFixedLog2LMUL1: 680 applyFixedLog2LMUL(1, FixedLMULType::LargerThan); 681 break; 682 case VectorTypeModifier::LFixedLog2LMUL2: 683 applyFixedLog2LMUL(2, FixedLMULType::LargerThan); 684 break; 685 case VectorTypeModifier::LFixedLog2LMUL3: 686 applyFixedLog2LMUL(3, FixedLMULType::LargerThan); 687 break; 688 case VectorTypeModifier::SFixedLog2LMULN3: 689 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan); 690 break; 691 case VectorTypeModifier::SFixedLog2LMULN2: 692 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan); 693 break; 694 case VectorTypeModifier::SFixedLog2LMULN1: 695 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan); 696 break; 697 case VectorTypeModifier::SFixedLog2LMUL0: 698 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan); 699 break; 700 case VectorTypeModifier::SFixedLog2LMUL1: 701 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan); 702 break; 703 case VectorTypeModifier::SFixedLog2LMUL2: 704 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan); 705 break; 706 case VectorTypeModifier::SFixedLog2LMUL3: 707 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan); 708 break; 709 case VectorTypeModifier::NoModifier: 710 break; 711 } 712 713 for (unsigned TypeModifierMaskShift = 0; 714 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset); 715 ++TypeModifierMaskShift) { 716 unsigned TypeModifierMask = 1 << TypeModifierMaskShift; 717 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) != 718 TypeModifierMask) 719 continue; 720 switch (static_cast<TypeModifier>(TypeModifierMask)) { 721 case TypeModifier::Pointer: 722 IsPointer = true; 723 break; 724 case TypeModifier::Const: 725 IsConstant = true; 726 break; 727 case TypeModifier::Immediate: 728 IsImmediate = true; 729 IsConstant = true; 730 break; 731 case TypeModifier::UnsignedInteger: 732 ScalarType = ScalarTypeKind::UnsignedInteger; 733 break; 734 case TypeModifier::SignedInteger: 735 ScalarType = ScalarTypeKind::SignedInteger; 736 break; 737 case TypeModifier::Float: 738 ScalarType = ScalarTypeKind::Float; 739 break; 740 case TypeModifier::LMUL1: 741 LMUL = LMULType(0); 742 // Update ElementBitwidth need to update Scale too. 743 Scale = LMUL.getScale(ElementBitwidth); 744 break; 745 default: 746 llvm_unreachable("Unknown type modifier mask!"); 747 } 748 } 749 } 750 751 void RVVType::applyLog2EEW(unsigned Log2EEW) { 752 // update new elmul = (eew/sew) * lmul 753 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 754 // update new eew 755 ElementBitwidth = 1 << Log2EEW; 756 ScalarType = ScalarTypeKind::SignedInteger; 757 Scale = LMUL.getScale(ElementBitwidth); 758 } 759 760 void RVVType::applyFixedSEW(unsigned NewSEW) { 761 // Set invalid type if src and dst SEW are same. 762 if (ElementBitwidth == NewSEW) { 763 ScalarType = ScalarTypeKind::Invalid; 764 return; 765 } 766 // Update new SEW 767 ElementBitwidth = NewSEW; 768 Scale = LMUL.getScale(ElementBitwidth); 769 } 770 771 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) { 772 switch (Type) { 773 case FixedLMULType::LargerThan: 774 if (Log2LMUL < LMUL.Log2LMUL) { 775 ScalarType = ScalarTypeKind::Invalid; 776 return; 777 } 778 break; 779 case FixedLMULType::SmallerThan: 780 if (Log2LMUL > LMUL.Log2LMUL) { 781 ScalarType = ScalarTypeKind::Invalid; 782 return; 783 } 784 break; 785 } 786 787 // Update new LMUL 788 LMUL = LMULType(Log2LMUL); 789 Scale = LMUL.getScale(ElementBitwidth); 790 } 791 792 Optional<RVVTypes> 793 RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 794 ArrayRef<PrototypeDescriptor> Prototype) { 795 // LMUL x NF must be less than or equal to 8. 796 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) 797 return llvm::None; 798 799 RVVTypes Types; 800 for (const PrototypeDescriptor &Proto : Prototype) { 801 auto T = computeType(BT, Log2LMUL, Proto); 802 if (!T.hasValue()) 803 return llvm::None; 804 // Record legal type index 805 Types.push_back(T.getValue()); 806 } 807 return Types; 808 } 809 810 // Compute the hash value of RVVType, used for cache the result of computeType. 811 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL, 812 PrototypeDescriptor Proto) { 813 // Layout of hash value: 814 // 0 8 16 24 32 40 815 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM | 816 assert(Log2LMUL >= -3 && Log2LMUL <= 3); 817 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 | 818 ((uint64_t)(Proto.PT & 0xff) << 16) | 819 ((uint64_t)(Proto.TM & 0xff) << 24) | 820 ((uint64_t)(Proto.VTM & 0xff) << 32); 821 } 822 823 Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL, 824 PrototypeDescriptor Proto) { 825 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto); 826 // Search first 827 auto It = LegalTypes.find(Idx); 828 if (It != LegalTypes.end()) 829 return &(It->second); 830 831 if (IllegalTypes.count(Idx)) 832 return llvm::None; 833 834 // Compute type and record the result. 835 RVVType T(BT, Log2LMUL, Proto); 836 if (T.isValid()) { 837 // Record legal type index and value. 838 LegalTypes.insert({Idx, T}); 839 return &(LegalTypes[Idx]); 840 } 841 // Record illegal type index. 842 IllegalTypes.insert(Idx); 843 return llvm::None; 844 } 845 846 //===----------------------------------------------------------------------===// 847 // RVVIntrinsic implementation 848 //===----------------------------------------------------------------------===// 849 RVVIntrinsic::RVVIntrinsic( 850 StringRef NewName, StringRef Suffix, StringRef NewOverloadedName, 851 StringRef OverloadedSuffix, StringRef IRName, bool IsMasked, 852 bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, 853 bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen, 854 const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes, 855 const std::vector<StringRef> &RequiredFeatures, unsigned NF) 856 : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme), 857 HasUnMaskedOverloaded(HasUnMaskedOverloaded), 858 HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()), 859 NF(NF) { 860 861 // Init BuiltinName, Name and OverloadedName 862 BuiltinName = NewName.str(); 863 Name = BuiltinName; 864 if (NewOverloadedName.empty()) 865 OverloadedName = NewName.split("_").first.str(); 866 else 867 OverloadedName = NewOverloadedName.str(); 868 if (!Suffix.empty()) 869 Name += "_" + Suffix.str(); 870 if (!OverloadedSuffix.empty()) 871 OverloadedName += "_" + OverloadedSuffix.str(); 872 if (IsMasked) { 873 BuiltinName += "_m"; 874 Name += "_m"; 875 } 876 877 // Init RISC-V extensions 878 for (const auto &T : OutInTypes) { 879 if (T->isFloatVector(16) || T->isFloat(16)) 880 RISCVPredefinedMacros |= RISCVPredefinedMacro::Zvfh; 881 if (T->isFloatVector(32)) 882 RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp32; 883 if (T->isFloatVector(64)) 884 RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp64; 885 if (T->isVector(64)) 886 RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELen64; 887 } 888 for (auto Feature : RequiredFeatures) { 889 if (Feature == "RV64") 890 RISCVPredefinedMacros |= RISCVPredefinedMacro::RV64; 891 // Note: Full multiply instruction (mulh, mulhu, mulhsu, smul) for EEW=64 892 // require V. 893 if (Feature == "FullMultiply" && 894 (RISCVPredefinedMacros & RISCVPredefinedMacro::VectorMaxELen64)) 895 RISCVPredefinedMacros |= RISCVPredefinedMacro::V; 896 } 897 898 // Init OutputType and InputTypes 899 OutputType = OutInTypes[0]; 900 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 901 902 // IntrinsicTypes is unmasked TA version index. Need to update it 903 // if there is merge operand (It is always in first operand). 904 IntrinsicTypes = NewIntrinsicTypes; 905 if ((IsMasked && HasMaskedOffOperand) || 906 (!IsMasked && hasPassthruOperand())) { 907 for (auto &I : IntrinsicTypes) { 908 if (I >= 0) 909 I += NF; 910 } 911 } 912 } 913 914 std::string RVVIntrinsic::getBuiltinTypeStr() const { 915 std::string S; 916 S += OutputType->getBuiltinStr(); 917 for (const auto &T : InputTypes) { 918 S += T->getBuiltinStr(); 919 } 920 return S; 921 } 922 923 std::string RVVIntrinsic::getSuffixStr( 924 BasicType Type, int Log2LMUL, 925 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) { 926 SmallVector<std::string> SuffixStrs; 927 for (auto PD : PrototypeDescriptors) { 928 auto T = RVVType::computeType(Type, Log2LMUL, PD); 929 SuffixStrs.push_back(T.getValue()->getShortStr()); 930 } 931 return join(SuffixStrs, "_"); 932 } 933 934 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) { 935 SmallVector<PrototypeDescriptor> PrototypeDescriptors; 936 const StringRef Primaries("evwqom0ztul"); 937 while (!Prototypes.empty()) { 938 size_t Idx = 0; 939 // Skip over complex prototype because it could contain primitive type 940 // character. 941 if (Prototypes[0] == '(') 942 Idx = Prototypes.find_first_of(')'); 943 Idx = Prototypes.find_first_of(Primaries, Idx); 944 assert(Idx != StringRef::npos); 945 auto PD = PrototypeDescriptor::parsePrototypeDescriptor( 946 Prototypes.slice(0, Idx + 1)); 947 if (!PD) 948 llvm_unreachable("Error during parsing prototype."); 949 PrototypeDescriptors.push_back(*PD); 950 Prototypes = Prototypes.drop_front(Idx + 1); 951 } 952 return PrototypeDescriptors; 953 } 954 955 } // end namespace RISCV 956 } // end namespace clang 957