1 //===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This tablegen backend is responsible for emitting riscv_vector.h which 10 // includes a declaration and definition of each intrinsic functions specified 11 // in https://github.com/riscv/rvv-intrinsic-doc. 12 // 13 // See also the documentation in include/clang/Basic/riscv_vector.td. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/SmallSet.h" 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/ADT/StringMap.h" 21 #include "llvm/ADT/StringSet.h" 22 #include "llvm/ADT/Twine.h" 23 #include "llvm/TableGen/Error.h" 24 #include "llvm/TableGen/Record.h" 25 #include <numeric> 26 27 using namespace llvm; 28 using BasicType = char; 29 using VScaleVal = Optional<unsigned>; 30 31 namespace { 32 33 // Exponential LMUL 34 struct LMULType { 35 int Log2LMUL; 36 LMULType(int Log2LMUL); 37 // Return the C/C++ string representation of LMUL 38 std::string str() const; 39 Optional<unsigned> getScale(unsigned ElementBitwidth) const; 40 void MulLog2LMUL(int Log2LMUL); 41 LMULType &operator*=(uint32_t RHS); 42 }; 43 44 // This class is compact representation of a valid and invalid RVVType. 45 class RVVType { 46 enum ScalarTypeKind : uint32_t { 47 Void, 48 Size_t, 49 Ptrdiff_t, 50 UnsignedLong, 51 SignedLong, 52 Boolean, 53 SignedInteger, 54 UnsignedInteger, 55 Float, 56 Invalid, 57 }; 58 BasicType BT; 59 ScalarTypeKind ScalarType = Invalid; 60 LMULType LMUL; 61 bool IsPointer = false; 62 // IsConstant indices are "int", but have the constant expression. 63 bool IsImmediate = false; 64 // Const qualifier for pointer to const object or object of const type. 65 bool IsConstant = false; 66 unsigned ElementBitwidth = 0; 67 VScaleVal Scale = 0; 68 bool Valid; 69 70 std::string BuiltinStr; 71 std::string ClangBuiltinStr; 72 std::string Str; 73 std::string ShortStr; 74 75 public: 76 RVVType() : RVVType(BasicType(), 0, StringRef()) {} 77 RVVType(BasicType BT, int Log2LMUL, StringRef prototype); 78 79 // Return the string representation of a type, which is an encoded string for 80 // passing to the BUILTIN() macro in Builtins.def. 81 const std::string &getBuiltinStr() const { return BuiltinStr; } 82 83 // Return the clang builtin type for RVV vector type which are used in the 84 // riscv_vector.h header file. 85 const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; } 86 87 // Return the C/C++ string representation of a type for use in the 88 // riscv_vector.h header file. 89 const std::string &getTypeStr() const { return Str; } 90 91 // Return the short name of a type for C/C++ name suffix. 92 const std::string &getShortStr() { 93 // Not all types are used in short name, so compute the short name by 94 // demanded. 95 if (ShortStr.empty()) 96 initShortStr(); 97 return ShortStr; 98 } 99 100 bool isValid() const { return Valid; } 101 bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; } 102 bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; } 103 bool isFloat() const { return ScalarType == ScalarTypeKind::Float; } 104 bool isSignedInteger() const { 105 return ScalarType == ScalarTypeKind::SignedInteger; 106 } 107 bool isFloatVector(unsigned Width) const { 108 return isVector() && isFloat() && ElementBitwidth == Width; 109 } 110 bool isFloat(unsigned Width) const { 111 return isFloat() && ElementBitwidth == Width; 112 } 113 114 private: 115 // Verify RVV vector type and set Valid. 116 bool verifyType() const; 117 118 // Creates a type based on basic types of TypeRange 119 void applyBasicType(); 120 121 // Applies a prototype modifier to the current type. The result maybe an 122 // invalid type. 123 void applyModifier(StringRef prototype); 124 125 // Compute and record a string for legal type. 126 void initBuiltinStr(); 127 // Compute and record a builtin RVV vector type string. 128 void initClangBuiltinStr(); 129 // Compute and record a type string for used in the header. 130 void initTypeStr(); 131 // Compute and record a short name of a type for C/C++ name suffix. 132 void initShortStr(); 133 }; 134 135 using RVVTypePtr = RVVType *; 136 using RVVTypes = std::vector<RVVTypePtr>; 137 138 enum RISCVExtension : uint8_t { 139 Basic = 0, 140 F = 1 << 1, 141 D = 1 << 2, 142 Zfh = 1 << 3, 143 Zvlsseg = 1 << 4, 144 RV64 = 1 << 5, 145 }; 146 147 // TODO refactor RVVIntrinsic class design after support all intrinsic 148 // combination. This represents an instantiation of an intrinsic with a 149 // particular type and prototype 150 class RVVIntrinsic { 151 152 private: 153 std::string BuiltinName; // Builtin name 154 std::string Name; // C intrinsic name. 155 std::string MangledName; 156 std::string IRName; 157 bool IsMask; 158 bool HasVL; 159 bool HasPolicy; 160 bool HasNoMaskedOverloaded; 161 bool HasAutoDef; // There is automiatic definition in header 162 std::string ManualCodegen; 163 RVVTypePtr OutputType; // Builtin output type 164 RVVTypes InputTypes; // Builtin input types 165 // The types we use to obtain the specific LLVM intrinsic. They are index of 166 // InputTypes. -1 means the return type. 167 std::vector<int64_t> IntrinsicTypes; 168 uint8_t RISCVExtensions = 0; 169 unsigned NF = 1; 170 171 public: 172 RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName, 173 StringRef MangledSuffix, StringRef IRName, bool IsMask, 174 bool HasMaskedOffOperand, bool HasVL, bool HasPolicy, 175 bool HasNoMaskedOverloaded, bool HasAutoDef, 176 StringRef ManualCodegen, const RVVTypes &Types, 177 const std::vector<int64_t> &IntrinsicTypes, 178 const std::vector<StringRef> &RequiredExtensions, unsigned NF); 179 ~RVVIntrinsic() = default; 180 181 StringRef getBuiltinName() const { return BuiltinName; } 182 StringRef getName() const { return Name; } 183 StringRef getMangledName() const { return MangledName; } 184 bool hasVL() const { return HasVL; } 185 bool hasPolicy() const { return HasPolicy; } 186 bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; } 187 bool hasManualCodegen() const { return !ManualCodegen.empty(); } 188 bool hasAutoDef() const { return HasAutoDef; } 189 bool isMask() const { return IsMask; } 190 StringRef getIRName() const { return IRName; } 191 StringRef getManualCodegen() const { return ManualCodegen; } 192 uint8_t getRISCVExtensions() const { return RISCVExtensions; } 193 unsigned getNF() const { return NF; } 194 const std::vector<int64_t> &getIntrinsicTypes() const { 195 return IntrinsicTypes; 196 } 197 198 // Return the type string for a BUILTIN() macro in Builtins.def. 199 std::string getBuiltinTypeStr() const; 200 201 // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should 202 // init the RVVIntrinsic ID and IntrinsicTypes. 203 void emitCodeGenSwitchBody(raw_ostream &o) const; 204 205 // Emit the macros for mapping C/C++ intrinsic function to builtin functions. 206 void emitIntrinsicFuncDef(raw_ostream &o) const; 207 208 // Emit the mangled function definition. 209 void emitMangledFuncDef(raw_ostream &o) const; 210 }; 211 212 class RVVEmitter { 213 private: 214 RecordKeeper &Records; 215 std::string HeaderCode; 216 // Concat BasicType, LMUL and Proto as key 217 StringMap<RVVType> LegalTypes; 218 StringSet<> IllegalTypes; 219 220 public: 221 RVVEmitter(RecordKeeper &R) : Records(R) {} 222 223 /// Emit riscv_vector.h 224 void createHeader(raw_ostream &o); 225 226 /// Emit all the __builtin prototypes and code needed by Sema. 227 void createBuiltins(raw_ostream &o); 228 229 /// Emit all the information needed to map builtin -> LLVM IR intrinsic. 230 void createCodeGen(raw_ostream &o); 231 232 std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes); 233 234 private: 235 /// Create all intrinsics and add them to \p Out 236 void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out); 237 /// Create Headers and add them to \p Out 238 void createRVVHeaders(raw_ostream &OS); 239 /// Compute output and input types by applying different config (basic type 240 /// and LMUL with type transformers). It also record result of type in legal 241 /// or illegal set to avoid compute the same config again. The result maybe 242 /// have illegal RVVType. 243 Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 244 ArrayRef<std::string> PrototypeSeq); 245 Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto); 246 247 /// Emit Acrh predecessor definitions and body, assume the element of Defs are 248 /// sorted by extension. 249 void emitArchMacroAndBody( 250 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o, 251 std::function<void(raw_ostream &, const RVVIntrinsic &)>); 252 253 // Emit the architecture preprocessor definitions. Return true when emits 254 // non-empty string. 255 bool emitExtDefStr(uint8_t Extensions, raw_ostream &o); 256 // Slice Prototypes string into sub prototype string and process each sub 257 // prototype string individually in the Handler. 258 void parsePrototypes(StringRef Prototypes, 259 std::function<void(StringRef)> Handler); 260 }; 261 262 } // namespace 263 264 //===----------------------------------------------------------------------===// 265 // Type implementation 266 //===----------------------------------------------------------------------===// 267 268 LMULType::LMULType(int NewLog2LMUL) { 269 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 270 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); 271 Log2LMUL = NewLog2LMUL; 272 } 273 274 std::string LMULType::str() const { 275 if (Log2LMUL < 0) 276 return "mf" + utostr(1ULL << (-Log2LMUL)); 277 return "m" + utostr(1ULL << Log2LMUL); 278 } 279 280 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { 281 int Log2ScaleResult = 0; 282 switch (ElementBitwidth) { 283 default: 284 break; 285 case 8: 286 Log2ScaleResult = Log2LMUL + 3; 287 break; 288 case 16: 289 Log2ScaleResult = Log2LMUL + 2; 290 break; 291 case 32: 292 Log2ScaleResult = Log2LMUL + 1; 293 break; 294 case 64: 295 Log2ScaleResult = Log2LMUL; 296 break; 297 } 298 // Illegal vscale result would be less than 1 299 if (Log2ScaleResult < 0) 300 return None; 301 return 1 << Log2ScaleResult; 302 } 303 304 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } 305 306 LMULType &LMULType::operator*=(uint32_t RHS) { 307 assert(isPowerOf2_32(RHS)); 308 this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); 309 return *this; 310 } 311 312 RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) 313 : BT(BT), LMUL(LMULType(Log2LMUL)) { 314 applyBasicType(); 315 applyModifier(prototype); 316 Valid = verifyType(); 317 if (Valid) { 318 initBuiltinStr(); 319 initTypeStr(); 320 if (isVector()) { 321 initClangBuiltinStr(); 322 } 323 } 324 } 325 326 // clang-format off 327 // boolean type are encoded the ratio of n (SEW/LMUL) 328 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 329 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t 330 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 331 332 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 333 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 334 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 335 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 336 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 337 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 338 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 339 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 340 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 341 // clang-format on 342 343 bool RVVType::verifyType() const { 344 if (ScalarType == Invalid) 345 return false; 346 if (isScalar()) 347 return true; 348 if (!Scale.hasValue()) 349 return false; 350 if (isFloat() && ElementBitwidth == 8) 351 return false; 352 unsigned V = Scale.getValue(); 353 switch (ElementBitwidth) { 354 case 1: 355 case 8: 356 // Check Scale is 1,2,4,8,16,32,64 357 return (V <= 64 && isPowerOf2_32(V)); 358 case 16: 359 // Check Scale is 1,2,4,8,16,32 360 return (V <= 32 && isPowerOf2_32(V)); 361 case 32: 362 // Check Scale is 1,2,4,8,16 363 return (V <= 16 && isPowerOf2_32(V)); 364 case 64: 365 // Check Scale is 1,2,4,8 366 return (V <= 8 && isPowerOf2_32(V)); 367 } 368 return false; 369 } 370 371 void RVVType::initBuiltinStr() { 372 assert(isValid() && "RVVType is invalid"); 373 switch (ScalarType) { 374 case ScalarTypeKind::Void: 375 BuiltinStr = "v"; 376 return; 377 case ScalarTypeKind::Size_t: 378 BuiltinStr = "z"; 379 if (IsImmediate) 380 BuiltinStr = "I" + BuiltinStr; 381 if (IsPointer) 382 BuiltinStr += "*"; 383 return; 384 case ScalarTypeKind::Ptrdiff_t: 385 BuiltinStr = "Y"; 386 return; 387 case ScalarTypeKind::UnsignedLong: 388 BuiltinStr = "ULi"; 389 return; 390 case ScalarTypeKind::SignedLong: 391 BuiltinStr = "Li"; 392 return; 393 case ScalarTypeKind::Boolean: 394 assert(ElementBitwidth == 1); 395 BuiltinStr += "b"; 396 break; 397 case ScalarTypeKind::SignedInteger: 398 case ScalarTypeKind::UnsignedInteger: 399 switch (ElementBitwidth) { 400 case 8: 401 BuiltinStr += "c"; 402 break; 403 case 16: 404 BuiltinStr += "s"; 405 break; 406 case 32: 407 BuiltinStr += "i"; 408 break; 409 case 64: 410 BuiltinStr += "Wi"; 411 break; 412 default: 413 llvm_unreachable("Unhandled ElementBitwidth!"); 414 } 415 if (isSignedInteger()) 416 BuiltinStr = "S" + BuiltinStr; 417 else 418 BuiltinStr = "U" + BuiltinStr; 419 break; 420 case ScalarTypeKind::Float: 421 switch (ElementBitwidth) { 422 case 16: 423 BuiltinStr += "x"; 424 break; 425 case 32: 426 BuiltinStr += "f"; 427 break; 428 case 64: 429 BuiltinStr += "d"; 430 break; 431 default: 432 llvm_unreachable("Unhandled ElementBitwidth!"); 433 } 434 break; 435 default: 436 llvm_unreachable("ScalarType is invalid!"); 437 } 438 if (IsImmediate) 439 BuiltinStr = "I" + BuiltinStr; 440 if (isScalar()) { 441 if (IsConstant) 442 BuiltinStr += "C"; 443 if (IsPointer) 444 BuiltinStr += "*"; 445 return; 446 } 447 BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; 448 // Pointer to vector types. Defined for Zvlsseg load intrinsics. 449 // Zvlsseg load intrinsics have pointer type arguments to store the loaded 450 // vector values. 451 if (IsPointer) 452 BuiltinStr += "*"; 453 } 454 455 void RVVType::initClangBuiltinStr() { 456 assert(isValid() && "RVVType is invalid"); 457 assert(isVector() && "Handle Vector type only"); 458 459 ClangBuiltinStr = "__rvv_"; 460 switch (ScalarType) { 461 case ScalarTypeKind::Boolean: 462 ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; 463 return; 464 case ScalarTypeKind::Float: 465 ClangBuiltinStr += "float"; 466 break; 467 case ScalarTypeKind::SignedInteger: 468 ClangBuiltinStr += "int"; 469 break; 470 case ScalarTypeKind::UnsignedInteger: 471 ClangBuiltinStr += "uint"; 472 break; 473 default: 474 llvm_unreachable("ScalarTypeKind is invalid"); 475 } 476 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; 477 } 478 479 void RVVType::initTypeStr() { 480 assert(isValid() && "RVVType is invalid"); 481 482 if (IsConstant) 483 Str += "const "; 484 485 auto getTypeString = [&](StringRef TypeStr) { 486 if (isScalar()) 487 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 488 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") 489 .str(); 490 }; 491 492 switch (ScalarType) { 493 case ScalarTypeKind::Void: 494 Str = "void"; 495 return; 496 case ScalarTypeKind::Size_t: 497 Str = "size_t"; 498 if (IsPointer) 499 Str += " *"; 500 return; 501 case ScalarTypeKind::Ptrdiff_t: 502 Str = "ptrdiff_t"; 503 return; 504 case ScalarTypeKind::UnsignedLong: 505 Str = "unsigned long"; 506 return; 507 case ScalarTypeKind::SignedLong: 508 Str = "long"; 509 return; 510 case ScalarTypeKind::Boolean: 511 if (isScalar()) 512 Str += "bool"; 513 else 514 // Vector bool is special case, the formulate is 515 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 516 Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; 517 break; 518 case ScalarTypeKind::Float: 519 if (isScalar()) { 520 if (ElementBitwidth == 64) 521 Str += "double"; 522 else if (ElementBitwidth == 32) 523 Str += "float"; 524 else if (ElementBitwidth == 16) 525 Str += "_Float16"; 526 else 527 llvm_unreachable("Unhandled floating type."); 528 } else 529 Str += getTypeString("float"); 530 break; 531 case ScalarTypeKind::SignedInteger: 532 Str += getTypeString("int"); 533 break; 534 case ScalarTypeKind::UnsignedInteger: 535 Str += getTypeString("uint"); 536 break; 537 default: 538 llvm_unreachable("ScalarType is invalid!"); 539 } 540 if (IsPointer) 541 Str += " *"; 542 } 543 544 void RVVType::initShortStr() { 545 switch (ScalarType) { 546 case ScalarTypeKind::Boolean: 547 assert(isVector()); 548 ShortStr = "b" + utostr(64 / Scale.getValue()); 549 return; 550 case ScalarTypeKind::Float: 551 ShortStr = "f" + utostr(ElementBitwidth); 552 break; 553 case ScalarTypeKind::SignedInteger: 554 ShortStr = "i" + utostr(ElementBitwidth); 555 break; 556 case ScalarTypeKind::UnsignedInteger: 557 ShortStr = "u" + utostr(ElementBitwidth); 558 break; 559 default: 560 PrintFatalError("Unhandled case!"); 561 } 562 if (isVector()) 563 ShortStr += LMUL.str(); 564 } 565 566 void RVVType::applyBasicType() { 567 switch (BT) { 568 case 'c': 569 ElementBitwidth = 8; 570 ScalarType = ScalarTypeKind::SignedInteger; 571 break; 572 case 's': 573 ElementBitwidth = 16; 574 ScalarType = ScalarTypeKind::SignedInteger; 575 break; 576 case 'i': 577 ElementBitwidth = 32; 578 ScalarType = ScalarTypeKind::SignedInteger; 579 break; 580 case 'l': 581 ElementBitwidth = 64; 582 ScalarType = ScalarTypeKind::SignedInteger; 583 break; 584 case 'x': 585 ElementBitwidth = 16; 586 ScalarType = ScalarTypeKind::Float; 587 break; 588 case 'f': 589 ElementBitwidth = 32; 590 ScalarType = ScalarTypeKind::Float; 591 break; 592 case 'd': 593 ElementBitwidth = 64; 594 ScalarType = ScalarTypeKind::Float; 595 break; 596 default: 597 PrintFatalError("Unhandled type code!"); 598 } 599 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 600 } 601 602 void RVVType::applyModifier(StringRef Transformer) { 603 if (Transformer.empty()) 604 return; 605 // Handle primitive type transformer 606 auto PType = Transformer.back(); 607 switch (PType) { 608 case 'e': 609 Scale = 0; 610 break; 611 case 'v': 612 Scale = LMUL.getScale(ElementBitwidth); 613 break; 614 case 'w': 615 ElementBitwidth *= 2; 616 LMUL *= 2; 617 Scale = LMUL.getScale(ElementBitwidth); 618 break; 619 case 'q': 620 ElementBitwidth *= 4; 621 LMUL *= 4; 622 Scale = LMUL.getScale(ElementBitwidth); 623 break; 624 case 'o': 625 ElementBitwidth *= 8; 626 LMUL *= 8; 627 Scale = LMUL.getScale(ElementBitwidth); 628 break; 629 case 'm': 630 ScalarType = ScalarTypeKind::Boolean; 631 Scale = LMUL.getScale(ElementBitwidth); 632 ElementBitwidth = 1; 633 break; 634 case '0': 635 ScalarType = ScalarTypeKind::Void; 636 break; 637 case 'z': 638 ScalarType = ScalarTypeKind::Size_t; 639 break; 640 case 't': 641 ScalarType = ScalarTypeKind::Ptrdiff_t; 642 break; 643 case 'u': 644 ScalarType = ScalarTypeKind::UnsignedLong; 645 break; 646 case 'l': 647 ScalarType = ScalarTypeKind::SignedLong; 648 break; 649 default: 650 PrintFatalError("Illegal primitive type transformers!"); 651 } 652 Transformer = Transformer.drop_back(); 653 654 // Extract and compute complex type transformer. It can only appear one time. 655 if (Transformer.startswith("(")) { 656 size_t Idx = Transformer.find(')'); 657 assert(Idx != StringRef::npos); 658 StringRef ComplexType = Transformer.slice(1, Idx); 659 Transformer = Transformer.drop_front(Idx + 1); 660 assert(!Transformer.contains('(') && 661 "Only allow one complex type transformer"); 662 663 auto UpdateAndCheckComplexProto = [&]() { 664 Scale = LMUL.getScale(ElementBitwidth); 665 const StringRef VectorPrototypes("vwqom"); 666 if (!VectorPrototypes.contains(PType)) 667 PrintFatalError("Complex type transformer only supports vector type!"); 668 if (Transformer.find_first_of("PCKWS") != StringRef::npos) 669 PrintFatalError( 670 "Illegal type transformer for Complex type transformer"); 671 }; 672 auto ComputeFixedLog2LMUL = 673 [&](StringRef Value, 674 std::function<bool(const int32_t &, const int32_t &)> Compare) { 675 int32_t Log2LMUL; 676 Value.getAsInteger(10, Log2LMUL); 677 if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { 678 ScalarType = Invalid; 679 return false; 680 } 681 // Update new LMUL 682 LMUL = LMULType(Log2LMUL); 683 UpdateAndCheckComplexProto(); 684 return true; 685 }; 686 auto ComplexTT = ComplexType.split(":"); 687 if (ComplexTT.first == "Log2EEW") { 688 uint32_t Log2EEW; 689 ComplexTT.second.getAsInteger(10, Log2EEW); 690 // update new elmul = (eew/sew) * lmul 691 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 692 // update new eew 693 ElementBitwidth = 1 << Log2EEW; 694 ScalarType = ScalarTypeKind::SignedInteger; 695 UpdateAndCheckComplexProto(); 696 } else if (ComplexTT.first == "FixedSEW") { 697 uint32_t NewSEW; 698 ComplexTT.second.getAsInteger(10, NewSEW); 699 // Set invalid type if src and dst SEW are same. 700 if (ElementBitwidth == NewSEW) { 701 ScalarType = Invalid; 702 return; 703 } 704 // Update new SEW 705 ElementBitwidth = NewSEW; 706 UpdateAndCheckComplexProto(); 707 } else if (ComplexTT.first == "LFixedLog2LMUL") { 708 // New LMUL should be larger than old 709 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>())) 710 return; 711 } else if (ComplexTT.first == "SFixedLog2LMUL") { 712 // New LMUL should be smaller than old 713 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>())) 714 return; 715 } else { 716 PrintFatalError("Illegal complex type transformers!"); 717 } 718 } 719 720 // Compute the remain type transformers 721 for (char I : Transformer) { 722 switch (I) { 723 case 'P': 724 if (IsConstant) 725 PrintFatalError("'P' transformer cannot be used after 'C'"); 726 if (IsPointer) 727 PrintFatalError("'P' transformer cannot be used twice"); 728 IsPointer = true; 729 break; 730 case 'C': 731 if (IsConstant) 732 PrintFatalError("'C' transformer cannot be used twice"); 733 IsConstant = true; 734 break; 735 case 'K': 736 IsImmediate = true; 737 break; 738 case 'U': 739 ScalarType = ScalarTypeKind::UnsignedInteger; 740 break; 741 case 'I': 742 ScalarType = ScalarTypeKind::SignedInteger; 743 break; 744 case 'F': 745 ScalarType = ScalarTypeKind::Float; 746 break; 747 case 'S': 748 LMUL = LMULType(0); 749 // Update ElementBitwidth need to update Scale too. 750 Scale = LMUL.getScale(ElementBitwidth); 751 break; 752 default: 753 PrintFatalError("Illegal non-primitive type transformer!"); 754 } 755 } 756 } 757 758 //===----------------------------------------------------------------------===// 759 // RVVIntrinsic implementation 760 //===----------------------------------------------------------------------===// 761 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, 762 StringRef NewMangledName, StringRef MangledSuffix, 763 StringRef IRName, bool IsMask, 764 bool HasMaskedOffOperand, bool HasVL, bool HasPolicy, 765 bool HasNoMaskedOverloaded, bool HasAutoDef, 766 StringRef ManualCodegen, const RVVTypes &OutInTypes, 767 const std::vector<int64_t> &NewIntrinsicTypes, 768 const std::vector<StringRef> &RequiredExtensions, 769 unsigned NF) 770 : IRName(IRName), IsMask(IsMask), HasVL(HasVL), HasPolicy(HasPolicy), 771 HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef), 772 ManualCodegen(ManualCodegen.str()), NF(NF) { 773 774 // Init BuiltinName, Name and MangledName 775 BuiltinName = NewName.str(); 776 Name = BuiltinName; 777 if (NewMangledName.empty()) 778 MangledName = NewName.split("_").first.str(); 779 else 780 MangledName = NewMangledName.str(); 781 if (!Suffix.empty()) 782 Name += "_" + Suffix.str(); 783 if (!MangledSuffix.empty()) 784 MangledName += "_" + MangledSuffix.str(); 785 if (IsMask) { 786 BuiltinName += "_m"; 787 Name += "_m"; 788 } 789 790 // Init RISC-V extensions 791 for (const auto &T : OutInTypes) { 792 if (T->isFloatVector(16) || T->isFloat(16)) 793 RISCVExtensions |= RISCVExtension::Zfh; 794 else if (T->isFloatVector(32) || T->isFloat(32)) 795 RISCVExtensions |= RISCVExtension::F; 796 else if (T->isFloatVector(64) || T->isFloat(64)) 797 RISCVExtensions |= RISCVExtension::D; 798 } 799 for (auto Extension : RequiredExtensions) { 800 if (Extension == "Zvlsseg") 801 RISCVExtensions |= RISCVExtension::Zvlsseg; 802 if (Extension == "RV64") 803 RISCVExtensions |= RISCVExtension::RV64; 804 } 805 806 // Init OutputType and InputTypes 807 OutputType = OutInTypes[0]; 808 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 809 810 // IntrinsicTypes is nonmasked version index. Need to update it 811 // if there is maskedoff operand (It is always in first operand). 812 IntrinsicTypes = NewIntrinsicTypes; 813 if (IsMask && HasMaskedOffOperand) { 814 for (auto &I : IntrinsicTypes) { 815 if (I >= 0) 816 I += NF; 817 } 818 } 819 } 820 821 std::string RVVIntrinsic::getBuiltinTypeStr() const { 822 std::string S; 823 S += OutputType->getBuiltinStr(); 824 for (const auto &T : InputTypes) { 825 S += T->getBuiltinStr(); 826 } 827 return S; 828 } 829 830 void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const { 831 if (!getIRName().empty()) 832 OS << " ID = Intrinsic::riscv_" + getIRName() + ";\n"; 833 if (NF >= 2) 834 OS << " NF = " + utostr(getNF()) + ";\n"; 835 if (hasManualCodegen()) { 836 OS << ManualCodegen; 837 OS << "break;\n"; 838 return; 839 } 840 841 if (isMask()) { 842 if (hasVL()) { 843 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n"; 844 if (hasPolicy()) 845 OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType()," 846 " TAIL_UNDISTURBED));\n"; 847 } else { 848 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n"; 849 } 850 } 851 852 OS << " IntrinsicTypes = {"; 853 ListSeparator LS; 854 for (const auto &Idx : IntrinsicTypes) { 855 if (Idx == -1) 856 OS << LS << "ResultType"; 857 else 858 OS << LS << "Ops[" << Idx << "]->getType()"; 859 } 860 861 // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is 862 // always last operand. 863 if (hasVL()) 864 OS << ", Ops.back()->getType()"; 865 OS << "};\n"; 866 OS << " break;\n"; 867 } 868 869 void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const { 870 OS << "__attribute__((__clang_builtin_alias__("; 871 OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; 872 OS << OutputType->getTypeStr() << " " << getName() << "("; 873 // Emit function arguments 874 if (!InputTypes.empty()) { 875 ListSeparator LS; 876 for (unsigned i = 0; i < InputTypes.size(); ++i) 877 OS << LS << InputTypes[i]->getTypeStr(); 878 } 879 OS << ");\n"; 880 } 881 882 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const { 883 OS << "__attribute__((__clang_builtin_alias__("; 884 OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; 885 OS << OutputType->getTypeStr() << " " << getMangledName() << "("; 886 // Emit function arguments 887 if (!InputTypes.empty()) { 888 ListSeparator LS; 889 for (unsigned i = 0; i < InputTypes.size(); ++i) 890 OS << LS << InputTypes[i]->getTypeStr(); 891 } 892 OS << ");\n"; 893 } 894 895 //===----------------------------------------------------------------------===// 896 // RVVEmitter implementation 897 //===----------------------------------------------------------------------===// 898 void RVVEmitter::createHeader(raw_ostream &OS) { 899 900 OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics " 901 "-------------------===\n" 902 " *\n" 903 " *\n" 904 " * Part of the LLVM Project, under the Apache License v2.0 with LLVM " 905 "Exceptions.\n" 906 " * See https://llvm.org/LICENSE.txt for license information.\n" 907 " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n" 908 " *\n" 909 " *===-----------------------------------------------------------------" 910 "------===\n" 911 " */\n\n"; 912 913 OS << "#ifndef __RISCV_VECTOR_H\n"; 914 OS << "#define __RISCV_VECTOR_H\n\n"; 915 916 OS << "#include <stdint.h>\n"; 917 OS << "#include <stddef.h>\n\n"; 918 919 OS << "#ifndef __riscv_vector\n"; 920 OS << "#error \"Vector intrinsics require the vector extension.\"\n"; 921 OS << "#endif\n\n"; 922 923 OS << "#ifdef __cplusplus\n"; 924 OS << "extern \"C\" {\n"; 925 OS << "#endif\n\n"; 926 927 createRVVHeaders(OS); 928 929 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 930 createRVVIntrinsics(Defs); 931 932 // Print header code 933 if (!HeaderCode.empty()) { 934 OS << HeaderCode; 935 } 936 937 auto printType = [&](auto T) { 938 OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr() 939 << ";\n"; 940 }; 941 942 constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3}; 943 // Print RVV boolean types. 944 for (int Log2LMUL : Log2LMULs) { 945 auto T = computeType('c', Log2LMUL, "m"); 946 if (T.hasValue()) 947 printType(T.getValue()); 948 } 949 // Print RVV int/float types. 950 for (char I : StringRef("csil")) { 951 for (int Log2LMUL : Log2LMULs) { 952 auto T = computeType(I, Log2LMUL, "v"); 953 if (T.hasValue()) { 954 printType(T.getValue()); 955 auto UT = computeType(I, Log2LMUL, "Uv"); 956 printType(UT.getValue()); 957 } 958 } 959 } 960 OS << "#if defined(__riscv_zfh)\n"; 961 for (int Log2LMUL : Log2LMULs) { 962 auto T = computeType('x', Log2LMUL, "v"); 963 if (T.hasValue()) 964 printType(T.getValue()); 965 } 966 OS << "#endif\n"; 967 968 OS << "#if defined(__riscv_f)\n"; 969 for (int Log2LMUL : Log2LMULs) { 970 auto T = computeType('f', Log2LMUL, "v"); 971 if (T.hasValue()) 972 printType(T.getValue()); 973 } 974 OS << "#endif\n"; 975 976 OS << "#if defined(__riscv_d)\n"; 977 for (int Log2LMUL : Log2LMULs) { 978 auto T = computeType('d', Log2LMUL, "v"); 979 if (T.hasValue()) 980 printType(T.getValue()); 981 } 982 OS << "#endif\n\n"; 983 984 // The same extension include in the same arch guard marco. 985 llvm::stable_sort(Defs, [](const std::unique_ptr<RVVIntrinsic> &A, 986 const std::unique_ptr<RVVIntrinsic> &B) { 987 return A->getRISCVExtensions() < B->getRISCVExtensions(); 988 }); 989 990 OS << "#define __rvv_ai static __inline__\n"; 991 992 // Print intrinsic functions with macro 993 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 994 OS << "__rvv_ai "; 995 Inst.emitIntrinsicFuncDef(OS); 996 }); 997 998 OS << "#undef __rvv_ai\n\n"; 999 1000 OS << "#define __riscv_v_intrinsic_overloading 1\n"; 1001 1002 // Print Overloaded APIs 1003 OS << "#define __rvv_aio static __inline__ " 1004 "__attribute__((__overloadable__))\n"; 1005 1006 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 1007 if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded()) 1008 return; 1009 OS << "__rvv_aio "; 1010 Inst.emitMangledFuncDef(OS); 1011 }); 1012 1013 OS << "#undef __rvv_aio\n"; 1014 1015 OS << "\n#ifdef __cplusplus\n"; 1016 OS << "}\n"; 1017 OS << "#endif // __cplusplus\n"; 1018 OS << "#endif // __RISCV_VECTOR_H\n"; 1019 } 1020 1021 void RVVEmitter::createBuiltins(raw_ostream &OS) { 1022 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 1023 createRVVIntrinsics(Defs); 1024 1025 // Map to keep track of which builtin names have already been emitted. 1026 StringMap<RVVIntrinsic *> BuiltinMap; 1027 1028 OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n"; 1029 OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, " 1030 "ATTRS, \"experimental-v\")\n"; 1031 OS << "#endif\n"; 1032 for (auto &Def : Defs) { 1033 auto P = 1034 BuiltinMap.insert(std::make_pair(Def->getBuiltinName(), Def.get())); 1035 if (!P.second) { 1036 // Verify that this would have produced the same builtin definition. 1037 if (P.first->second->hasAutoDef() != Def->hasAutoDef()) { 1038 PrintFatalError("Builtin with same name has different hasAutoDef"); 1039 } else if (!Def->hasAutoDef() && P.first->second->getBuiltinTypeStr() != 1040 Def->getBuiltinTypeStr()) { 1041 PrintFatalError("Builtin with same name has different type string"); 1042 } 1043 continue; 1044 } 1045 1046 OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getBuiltinName() << ",\""; 1047 if (!Def->hasAutoDef()) 1048 OS << Def->getBuiltinTypeStr(); 1049 OS << "\", \"n\")\n"; 1050 } 1051 OS << "#undef RISCVV_BUILTIN\n"; 1052 } 1053 1054 void RVVEmitter::createCodeGen(raw_ostream &OS) { 1055 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 1056 createRVVIntrinsics(Defs); 1057 // IR name could be empty, use the stable sort preserves the relative order. 1058 llvm::stable_sort(Defs, [](const std::unique_ptr<RVVIntrinsic> &A, 1059 const std::unique_ptr<RVVIntrinsic> &B) { 1060 return A->getIRName() < B->getIRName(); 1061 }); 1062 1063 // Map to keep track of which builtin names have already been emitted. 1064 StringMap<RVVIntrinsic *> BuiltinMap; 1065 1066 // Print switch body when the ir name or ManualCodegen changes from previous 1067 // iteration. 1068 RVVIntrinsic *PrevDef = Defs.begin()->get(); 1069 for (auto &Def : Defs) { 1070 StringRef CurIRName = Def->getIRName(); 1071 if (CurIRName != PrevDef->getIRName() || 1072 (Def->getManualCodegen() != PrevDef->getManualCodegen())) { 1073 PrevDef->emitCodeGenSwitchBody(OS); 1074 } 1075 PrevDef = Def.get(); 1076 1077 auto P = 1078 BuiltinMap.insert(std::make_pair(Def->getBuiltinName(), Def.get())); 1079 if (P.second) { 1080 OS << "case RISCVVector::BI__builtin_rvv_" << Def->getBuiltinName() 1081 << ":\n"; 1082 continue; 1083 } 1084 1085 if (P.first->second->getIRName() != Def->getIRName()) 1086 PrintFatalError("Builtin with same name has different IRName"); 1087 else if (P.first->second->getManualCodegen() != Def->getManualCodegen()) 1088 PrintFatalError("Builtin with same name has different ManualCodegen"); 1089 else if (P.first->second->getNF() != Def->getNF()) 1090 PrintFatalError("Builtin with same name has different NF"); 1091 else if (P.first->second->isMask() != Def->isMask()) 1092 PrintFatalError("Builtin with same name has different isMask"); 1093 else if (P.first->second->hasVL() != Def->hasVL()) 1094 PrintFatalError("Builtin with same name has different HasPolicy"); 1095 else if (P.first->second->hasPolicy() != Def->hasPolicy()) 1096 PrintFatalError("Builtin with same name has different HasPolicy"); 1097 else if (P.first->second->getIntrinsicTypes() != Def->getIntrinsicTypes()) 1098 PrintFatalError("Builtin with same name has different IntrinsicTypes"); 1099 } 1100 Defs.back()->emitCodeGenSwitchBody(OS); 1101 OS << "\n"; 1102 } 1103 1104 void RVVEmitter::parsePrototypes(StringRef Prototypes, 1105 std::function<void(StringRef)> Handler) { 1106 const StringRef Primaries("evwqom0ztul"); 1107 while (!Prototypes.empty()) { 1108 size_t Idx = 0; 1109 // Skip over complex prototype because it could contain primitive type 1110 // character. 1111 if (Prototypes[0] == '(') 1112 Idx = Prototypes.find_first_of(')'); 1113 Idx = Prototypes.find_first_of(Primaries, Idx); 1114 assert(Idx != StringRef::npos); 1115 Handler(Prototypes.slice(0, Idx + 1)); 1116 Prototypes = Prototypes.drop_front(Idx + 1); 1117 } 1118 } 1119 1120 std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL, 1121 StringRef Prototypes) { 1122 SmallVector<std::string> SuffixStrs; 1123 parsePrototypes(Prototypes, [&](StringRef Proto) { 1124 auto T = computeType(Type, Log2LMUL, Proto); 1125 SuffixStrs.push_back(T.getValue()->getShortStr()); 1126 }); 1127 return join(SuffixStrs, "_"); 1128 } 1129 1130 void RVVEmitter::createRVVIntrinsics( 1131 std::vector<std::unique_ptr<RVVIntrinsic>> &Out) { 1132 std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin"); 1133 for (auto *R : RV) { 1134 StringRef Name = R->getValueAsString("Name"); 1135 StringRef SuffixProto = R->getValueAsString("Suffix"); 1136 StringRef MangledName = R->getValueAsString("MangledName"); 1137 StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix"); 1138 StringRef Prototypes = R->getValueAsString("Prototype"); 1139 StringRef TypeRange = R->getValueAsString("TypeRange"); 1140 bool HasMask = R->getValueAsBit("HasMask"); 1141 bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand"); 1142 bool HasVL = R->getValueAsBit("HasVL"); 1143 bool HasPolicy = R->getValueAsBit("HasPolicy"); 1144 bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded"); 1145 std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL"); 1146 StringRef ManualCodegen = R->getValueAsString("ManualCodegen"); 1147 StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask"); 1148 std::vector<int64_t> IntrinsicTypes = 1149 R->getValueAsListOfInts("IntrinsicTypes"); 1150 std::vector<StringRef> RequiredExtensions = 1151 R->getValueAsListOfStrings("RequiredExtensions"); 1152 StringRef IRName = R->getValueAsString("IRName"); 1153 StringRef IRNameMask = R->getValueAsString("IRNameMask"); 1154 unsigned NF = R->getValueAsInt("NF"); 1155 1156 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode"); 1157 bool HasAutoDef = HeaderCodeStr.empty(); 1158 if (!HeaderCodeStr.empty()) { 1159 HeaderCode += HeaderCodeStr.str(); 1160 } 1161 // Parse prototype and create a list of primitive type with transformers 1162 // (operand) in ProtoSeq. ProtoSeq[0] is output operand. 1163 SmallVector<std::string> ProtoSeq; 1164 parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { 1165 ProtoSeq.push_back(Proto.str()); 1166 }); 1167 1168 // Compute Builtin types 1169 SmallVector<std::string> ProtoMaskSeq = ProtoSeq; 1170 if (HasMask) { 1171 // If HasMaskedOffOperand, insert result type as first input operand. 1172 if (HasMaskedOffOperand) { 1173 if (NF == 1) { 1174 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]); 1175 } else { 1176 // Convert 1177 // (void, op0 address, op1 address, ...) 1178 // to 1179 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 1180 for (unsigned I = 0; I < NF; ++I) 1181 ProtoMaskSeq.insert( 1182 ProtoMaskSeq.begin() + NF + 1, 1183 ProtoSeq[1].substr(1)); // Use substr(1) to skip '*' 1184 } 1185 } 1186 if (HasMaskedOffOperand && NF > 1) { 1187 // Convert 1188 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 1189 // to 1190 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, 1191 // ...) 1192 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m"); 1193 } else { 1194 // If HasMask, insert 'm' as first input operand. 1195 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m"); 1196 } 1197 } 1198 // If HasVL, append 'z' to last operand 1199 if (HasVL) { 1200 ProtoSeq.push_back("z"); 1201 ProtoMaskSeq.push_back("z"); 1202 } 1203 1204 // Create Intrinsics for each type and LMUL. 1205 for (char I : TypeRange) { 1206 for (int Log2LMUL : Log2LMULList) { 1207 Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq); 1208 // Ignored to create new intrinsic if there are any illegal types. 1209 if (!Types.hasValue()) 1210 continue; 1211 1212 auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto); 1213 auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto); 1214 // Create a non-mask intrinsic 1215 Out.push_back(std::make_unique<RVVIntrinsic>( 1216 Name, SuffixStr, MangledName, MangledSuffixStr, IRName, 1217 /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL, HasPolicy, 1218 HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, Types.getValue(), 1219 IntrinsicTypes, RequiredExtensions, NF)); 1220 if (HasMask) { 1221 // Create a mask intrinsic 1222 Optional<RVVTypes> MaskTypes = 1223 computeTypes(I, Log2LMUL, NF, ProtoMaskSeq); 1224 Out.push_back(std::make_unique<RVVIntrinsic>( 1225 Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask, 1226 /*IsMask=*/true, HasMaskedOffOperand, HasVL, HasPolicy, 1227 HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask, 1228 MaskTypes.getValue(), IntrinsicTypes, RequiredExtensions, NF)); 1229 } 1230 } // end for Log2LMULList 1231 } // end for TypeRange 1232 } 1233 } 1234 1235 void RVVEmitter::createRVVHeaders(raw_ostream &OS) { 1236 std::vector<Record *> RVVHeaders = 1237 Records.getAllDerivedDefinitions("RVVHeader"); 1238 for (auto *R : RVVHeaders) { 1239 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode"); 1240 OS << HeaderCodeStr.str(); 1241 } 1242 } 1243 1244 Optional<RVVTypes> 1245 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 1246 ArrayRef<std::string> PrototypeSeq) { 1247 // LMUL x NF must be less than or equal to 8. 1248 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) 1249 return llvm::None; 1250 1251 RVVTypes Types; 1252 for (const std::string &Proto : PrototypeSeq) { 1253 auto T = computeType(BT, Log2LMUL, Proto); 1254 if (!T.hasValue()) 1255 return llvm::None; 1256 // Record legal type index 1257 Types.push_back(T.getValue()); 1258 } 1259 return Types; 1260 } 1261 1262 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL, 1263 StringRef Proto) { 1264 std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str(); 1265 // Search first 1266 auto It = LegalTypes.find(Idx); 1267 if (It != LegalTypes.end()) 1268 return &(It->second); 1269 if (IllegalTypes.count(Idx)) 1270 return llvm::None; 1271 // Compute type and record the result. 1272 RVVType T(BT, Log2LMUL, Proto); 1273 if (T.isValid()) { 1274 // Record legal type index and value. 1275 LegalTypes.insert({Idx, T}); 1276 return &(LegalTypes[Idx]); 1277 } 1278 // Record illegal type index. 1279 IllegalTypes.insert(Idx); 1280 return llvm::None; 1281 } 1282 1283 void RVVEmitter::emitArchMacroAndBody( 1284 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS, 1285 std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) { 1286 uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions(); 1287 bool NeedEndif = emitExtDefStr(PrevExt, OS); 1288 for (auto &Def : Defs) { 1289 uint8_t CurExt = Def->getRISCVExtensions(); 1290 if (CurExt != PrevExt) { 1291 if (NeedEndif) 1292 OS << "#endif\n\n"; 1293 NeedEndif = emitExtDefStr(CurExt, OS); 1294 PrevExt = CurExt; 1295 } 1296 if (Def->hasAutoDef()) 1297 PrintBody(OS, *Def); 1298 } 1299 if (NeedEndif) 1300 OS << "#endif\n\n"; 1301 } 1302 1303 bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) { 1304 if (Extents == RISCVExtension::Basic) 1305 return false; 1306 OS << "#if "; 1307 ListSeparator LS(" && "); 1308 if (Extents & RISCVExtension::F) 1309 OS << LS << "defined(__riscv_f)"; 1310 if (Extents & RISCVExtension::D) 1311 OS << LS << "defined(__riscv_d)"; 1312 if (Extents & RISCVExtension::Zfh) 1313 OS << LS << "defined(__riscv_zfh)"; 1314 if (Extents & RISCVExtension::Zvlsseg) 1315 OS << LS << "defined(__riscv_zvlsseg)"; 1316 if (Extents & RISCVExtension::RV64) 1317 OS << LS << "(__riscv_xlen == 64)"; 1318 OS << "\n"; 1319 return true; 1320 } 1321 1322 namespace clang { 1323 void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) { 1324 RVVEmitter(Records).createHeader(OS); 1325 } 1326 1327 void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) { 1328 RVVEmitter(Records).createBuiltins(OS); 1329 } 1330 1331 void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) { 1332 RVVEmitter(Records).createCodeGen(OS); 1333 } 1334 1335 } // End namespace clang 1336