1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <set>
#include <unordered_map>
21 
22 using namespace llvm;
23 
24 namespace clang {
25 namespace RISCV {
26 
27 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
28     BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
29 const PrototypeDescriptor PrototypeDescriptor::VL =
30     PrototypeDescriptor(BaseTypeModifier::SizeT);
31 const PrototypeDescriptor PrototypeDescriptor::Vector =
32     PrototypeDescriptor(BaseTypeModifier::Vector);
33 
// Memoization caches for RVVType::computeType. Both are keyed by the integer
// produced by computeRVVTypeHashValue, which packs BasicType, Log2LMUL and the
// prototype descriptor fields into one uint64_t.
//  - LegalTypes owns every valid RVVType built so far. computeType hands out
//    raw pointers to its elements; std::unordered_map never relocates
//    elements, so those pointers stay valid as the map grows.
//  - IllegalTypes remembers keys already known to yield an invalid type so
//    they are not rebuilt.
// NOTE(review): these are mutable file-level statics with no synchronization
// visible here; concurrent callers would need external locking.
static std::unordered_map<uint64_t, RVVType> LegalTypes;
static std::set<uint64_t> IllegalTypes;
37 
38 //===----------------------------------------------------------------------===//
39 // Type implementation
40 //===----------------------------------------------------------------------===//
41 
42 LMULType::LMULType(int NewLog2LMUL) {
43   // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
44   assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
45   Log2LMUL = NewLog2LMUL;
46 }
47 
48 std::string LMULType::str() const {
49   if (Log2LMUL < 0)
50     return "mf" + utostr(1ULL << (-Log2LMUL));
51   return "m" + utostr(1ULL << Log2LMUL);
52 }
53 
54 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
55   int Log2ScaleResult = 0;
56   switch (ElementBitwidth) {
57   default:
58     break;
59   case 8:
60     Log2ScaleResult = Log2LMUL + 3;
61     break;
62   case 16:
63     Log2ScaleResult = Log2LMUL + 2;
64     break;
65   case 32:
66     Log2ScaleResult = Log2LMUL + 1;
67     break;
68   case 64:
69     Log2ScaleResult = Log2LMUL;
70     break;
71   }
72   // Illegal vscale result would be less than 1
73   if (Log2ScaleResult < 0)
74     return llvm::None;
75   return 1 << Log2ScaleResult;
76 }
77 
78 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
79 
80 RVVType::RVVType(BasicType BT, int Log2LMUL,
81                  const PrototypeDescriptor &prototype)
82     : BT(BT), LMUL(LMULType(Log2LMUL)) {
83   applyBasicType();
84   applyModifier(prototype);
85   Valid = verifyType();
86   if (Valid) {
87     initBuiltinStr();
88     initTypeStr();
89     if (isVector()) {
90       initClangBuiltinStr();
91     }
92   }
93 }
94 
95 // clang-format off
// Boolean (mask) types are encoded by the ratio N = SEW/LMUL:
97 // SEW/LMUL | 1         | 2         | 4         | 8        | 16        | 32        | 64
98 // c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
99 // IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
100 
101 // type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
102 // --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
103 // i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
104 // i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
105 // i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
106 // i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
107 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
108 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
109 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
110 // clang-format on
111 
112 bool RVVType::verifyType() const {
113   if (ScalarType == Invalid)
114     return false;
115   if (isScalar())
116     return true;
117   if (!Scale.hasValue())
118     return false;
119   if (isFloat() && ElementBitwidth == 8)
120     return false;
121   unsigned V = Scale.getValue();
122   switch (ElementBitwidth) {
123   case 1:
124   case 8:
125     // Check Scale is 1,2,4,8,16,32,64
126     return (V <= 64 && isPowerOf2_32(V));
127   case 16:
128     // Check Scale is 1,2,4,8,16,32
129     return (V <= 32 && isPowerOf2_32(V));
130   case 32:
131     // Check Scale is 1,2,4,8,16
132     return (V <= 16 && isPowerOf2_32(V));
133   case 64:
134     // Check Scale is 1,2,4,8
135     return (V <= 8 && isPowerOf2_32(V));
136   }
137   return false;
138 }
139 
/// Build BuiltinStr, the compact type encoding of this (valid) type.
///
/// Scalar codes used below:
///  - "v" void, "z" size_t, "Y" ptrdiff_t, "ULi" unsigned long, "Li" long.
///  - "b" bool.
///  - "c"/"s"/"i"/"Wi" for 8/16/32/64-bit integers, prefixed by "S" or "U".
///  - "x"/"f"/"d" for 16/32/64-bit floats.
/// Vector types are additionally prefixed with "q<Scale>". The order in which
/// prefixes ("I", "S"/"U", "q<N>") and suffixes ("C", "*") are attached is
/// significant.
void RVVType::initBuiltinStr() {
  assert(isValid() && "RVVType is invalid");
  switch (ScalarType) {
  case ScalarTypeKind::Void:
    BuiltinStr = "v";
    return;
  case ScalarTypeKind::Size_t:
    // size_t applies its own immediate/pointer markers and returns early, so
    // it never receives the "C" const suffix added further below.
    BuiltinStr = "z";
    if (IsImmediate)
      BuiltinStr = "I" + BuiltinStr;
    if (IsPointer)
      BuiltinStr += "*";
    return;
  case ScalarTypeKind::Ptrdiff_t:
    BuiltinStr = "Y";
    return;
  case ScalarTypeKind::UnsignedLong:
    BuiltinStr = "ULi";
    return;
  case ScalarTypeKind::SignedLong:
    BuiltinStr = "Li";
    return;
  // The remaining kinds append to BuiltinStr (presumably empty on entry) and
  // fall through to the shared prefix/suffix handling after the switch.
  case ScalarTypeKind::Boolean:
    assert(ElementBitwidth == 1);
    BuiltinStr += "b";
    break;
  case ScalarTypeKind::SignedInteger:
  case ScalarTypeKind::UnsignedInteger:
    switch (ElementBitwidth) {
    case 8:
      BuiltinStr += "c";
      break;
    case 16:
      BuiltinStr += "s";
      break;
    case 32:
      BuiltinStr += "i";
      break;
    case 64:
      BuiltinStr += "Wi";
      break;
    default:
      llvm_unreachable("Unhandled ElementBitwidth!");
    }
    // Integer encodings carry an explicit signedness prefix.
    if (isSignedInteger())
      BuiltinStr = "S" + BuiltinStr;
    else
      BuiltinStr = "U" + BuiltinStr;
    break;
  case ScalarTypeKind::Float:
    switch (ElementBitwidth) {
    case 16:
      BuiltinStr += "x";
      break;
    case 32:
      BuiltinStr += "f";
      break;
    case 64:
      BuiltinStr += "d";
      break;
    default:
      llvm_unreachable("Unhandled ElementBitwidth!");
    }
    break;
  default:
    llvm_unreachable("ScalarType is invalid!");
  }
  // "I" marks an immediate operand; it goes in front of the element code.
  if (IsImmediate)
    BuiltinStr = "I" + BuiltinStr;
  // Scalars get optional const ("C") and pointer ("*") suffixes and are done.
  if (isScalar()) {
    if (IsConstant)
      BuiltinStr += "C";
    if (IsPointer)
      BuiltinStr += "*";
    return;
  }
  // Vectors are encoded as "q<Scale><element code>"; no "C" suffix is added.
  BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
  // Pointer to vector types. Defined for segment load intrinsics.
  // segment load intrinsics have pointer type arguments to store the loaded
  // vector values.
  if (IsPointer)
    BuiltinStr += "*";
}
223 
224 void RVVType::initClangBuiltinStr() {
225   assert(isValid() && "RVVType is invalid");
226   assert(isVector() && "Handle Vector type only");
227 
228   ClangBuiltinStr = "__rvv_";
229   switch (ScalarType) {
230   case ScalarTypeKind::Boolean:
231     ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
232     return;
233   case ScalarTypeKind::Float:
234     ClangBuiltinStr += "float";
235     break;
236   case ScalarTypeKind::SignedInteger:
237     ClangBuiltinStr += "int";
238     break;
239   case ScalarTypeKind::UnsignedInteger:
240     ClangBuiltinStr += "uint";
241     break;
242   default:
243     llvm_unreachable("ScalarTypeKind is invalid");
244   }
245   ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
246 }
247 
/// Build Str, the C spelling of this (valid) type, e.g. "vint32m1_t",
/// "const _Float16 *" or "size_t".
///
/// NOTE(review): the "const " prefix is added up front, but the Void, Size_t,
/// Ptrdiff_t, UnsignedLong and SignedLong cases assign Str with `=` and
/// return, so they discard that prefix. Of those kinds, only Size_t honours
/// IsPointer. Confirm this is intended before relying on it.
void RVVType::initTypeStr() {
  assert(isValid() && "RVVType is invalid");

  if (IsConstant)
    Str += "const ";

  // Scalars: "<kind><SEW>_t" (e.g. int8_t); vectors: "v<kind><SEW><lmul>_t"
  // (e.g. vuint16mf2_t).
  auto getTypeString = [&](StringRef TypeStr) {
    if (isScalar())
      return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
    return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
        .str();
  };

  switch (ScalarType) {
  case ScalarTypeKind::Void:
    Str = "void";
    return;
  case ScalarTypeKind::Size_t:
    Str = "size_t";
    if (IsPointer)
      Str += " *";
    return;
  case ScalarTypeKind::Ptrdiff_t:
    Str = "ptrdiff_t";
    return;
  case ScalarTypeKind::UnsignedLong:
    Str = "unsigned long";
    return;
  case ScalarTypeKind::SignedLong:
    Str = "long";
    return;
  case ScalarTypeKind::Boolean:
    if (isScalar())
      Str += "bool";
    else
      // Vector bool is a special case; the formula is
      // `vbool<N>_t = MVT::nxv<64/N>i1`, e.g. vbool16_t = MVT::nxv4i1
      Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
    break;
  case ScalarTypeKind::Float:
    if (isScalar()) {
      if (ElementBitwidth == 64)
        Str += "double";
      else if (ElementBitwidth == 32)
        Str += "float";
      else if (ElementBitwidth == 16)
        Str += "_Float16";
      else
        llvm_unreachable("Unhandled floating type.");
    } else
      Str += getTypeString("float");
    break;
  case ScalarTypeKind::SignedInteger:
    Str += getTypeString("int");
    break;
  case ScalarTypeKind::UnsignedInteger:
    Str += getTypeString("uint");
    break;
  default:
    llvm_unreachable("ScalarType is invalid!");
  }
  if (IsPointer)
    Str += " *";
}
312 
313 void RVVType::initShortStr() {
314   switch (ScalarType) {
315   case ScalarTypeKind::Boolean:
316     assert(isVector());
317     ShortStr = "b" + utostr(64 / Scale.getValue());
318     return;
319   case ScalarTypeKind::Float:
320     ShortStr = "f" + utostr(ElementBitwidth);
321     break;
322   case ScalarTypeKind::SignedInteger:
323     ShortStr = "i" + utostr(ElementBitwidth);
324     break;
325   case ScalarTypeKind::UnsignedInteger:
326     ShortStr = "u" + utostr(ElementBitwidth);
327     break;
328   default:
329     llvm_unreachable("Unhandled case!");
330   }
331   if (isVector())
332     ShortStr += LMUL.str();
333 }
334 
335 void RVVType::applyBasicType() {
336   switch (BT) {
337   case BasicType::Int8:
338     ElementBitwidth = 8;
339     ScalarType = ScalarTypeKind::SignedInteger;
340     break;
341   case BasicType::Int16:
342     ElementBitwidth = 16;
343     ScalarType = ScalarTypeKind::SignedInteger;
344     break;
345   case BasicType::Int32:
346     ElementBitwidth = 32;
347     ScalarType = ScalarTypeKind::SignedInteger;
348     break;
349   case BasicType::Int64:
350     ElementBitwidth = 64;
351     ScalarType = ScalarTypeKind::SignedInteger;
352     break;
353   case BasicType::Float16:
354     ElementBitwidth = 16;
355     ScalarType = ScalarTypeKind::Float;
356     break;
357   case BasicType::Float32:
358     ElementBitwidth = 32;
359     ScalarType = ScalarTypeKind::Float;
360     break;
361   case BasicType::Float64:
362     ElementBitwidth = 64;
363     ScalarType = ScalarTypeKind::Float;
364     break;
365   default:
366     llvm_unreachable("Unhandled type code!");
367   }
368   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
369 }
370 
371 Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
372     llvm::StringRef PrototypeDescriptorStr) {
373   PrototypeDescriptor PD;
374   BaseTypeModifier PT = BaseTypeModifier::Invalid;
375   VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
376 
377   if (PrototypeDescriptorStr.empty())
378     return PD;
379 
380   // Handle base type modifier
381   auto PType = PrototypeDescriptorStr.back();
382   switch (PType) {
383   case 'e':
384     PT = BaseTypeModifier::Scalar;
385     break;
386   case 'v':
387     PT = BaseTypeModifier::Vector;
388     break;
389   case 'w':
390     PT = BaseTypeModifier::Vector;
391     VTM = VectorTypeModifier::Widening2XVector;
392     break;
393   case 'q':
394     PT = BaseTypeModifier::Vector;
395     VTM = VectorTypeModifier::Widening4XVector;
396     break;
397   case 'o':
398     PT = BaseTypeModifier::Vector;
399     VTM = VectorTypeModifier::Widening8XVector;
400     break;
401   case 'm':
402     PT = BaseTypeModifier::Vector;
403     VTM = VectorTypeModifier::MaskVector;
404     break;
405   case '0':
406     PT = BaseTypeModifier::Void;
407     break;
408   case 'z':
409     PT = BaseTypeModifier::SizeT;
410     break;
411   case 't':
412     PT = BaseTypeModifier::Ptrdiff;
413     break;
414   case 'u':
415     PT = BaseTypeModifier::UnsignedLong;
416     break;
417   case 'l':
418     PT = BaseTypeModifier::SignedLong;
419     break;
420   default:
421     llvm_unreachable("Illegal primitive type transformers!");
422   }
423   PD.PT = static_cast<uint8_t>(PT);
424   PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
425 
426   // Compute the vector type transformers, it can only appear one time.
427   if (PrototypeDescriptorStr.startswith("(")) {
428     assert(VTM == VectorTypeModifier::NoModifier &&
429            "VectorTypeModifier should only have one modifier");
430     size_t Idx = PrototypeDescriptorStr.find(')');
431     assert(Idx != StringRef::npos);
432     StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
433     PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
434     assert(!PrototypeDescriptorStr.contains('(') &&
435            "Only allow one vector type modifier");
436 
437     auto ComplexTT = ComplexType.split(":");
438     if (ComplexTT.first == "Log2EEW") {
439       uint32_t Log2EEW;
440       if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
441         llvm_unreachable("Invalid Log2EEW value!");
442         return None;
443       }
444       switch (Log2EEW) {
445       case 3:
446         VTM = VectorTypeModifier::Log2EEW3;
447         break;
448       case 4:
449         VTM = VectorTypeModifier::Log2EEW4;
450         break;
451       case 5:
452         VTM = VectorTypeModifier::Log2EEW5;
453         break;
454       case 6:
455         VTM = VectorTypeModifier::Log2EEW6;
456         break;
457       default:
458         llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
459         return None;
460       }
461     } else if (ComplexTT.first == "FixedSEW") {
462       uint32_t NewSEW;
463       if (ComplexTT.second.getAsInteger(10, NewSEW)) {
464         llvm_unreachable("Invalid FixedSEW value!");
465         return None;
466       }
467       switch (NewSEW) {
468       case 8:
469         VTM = VectorTypeModifier::FixedSEW8;
470         break;
471       case 16:
472         VTM = VectorTypeModifier::FixedSEW16;
473         break;
474       case 32:
475         VTM = VectorTypeModifier::FixedSEW32;
476         break;
477       case 64:
478         VTM = VectorTypeModifier::FixedSEW64;
479         break;
480       default:
481         llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
482         return None;
483       }
484     } else if (ComplexTT.first == "LFixedLog2LMUL") {
485       int32_t Log2LMUL;
486       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
487         llvm_unreachable("Invalid LFixedLog2LMUL value!");
488         return None;
489       }
490       switch (Log2LMUL) {
491       case -3:
492         VTM = VectorTypeModifier::LFixedLog2LMULN3;
493         break;
494       case -2:
495         VTM = VectorTypeModifier::LFixedLog2LMULN2;
496         break;
497       case -1:
498         VTM = VectorTypeModifier::LFixedLog2LMULN1;
499         break;
500       case 0:
501         VTM = VectorTypeModifier::LFixedLog2LMUL0;
502         break;
503       case 1:
504         VTM = VectorTypeModifier::LFixedLog2LMUL1;
505         break;
506       case 2:
507         VTM = VectorTypeModifier::LFixedLog2LMUL2;
508         break;
509       case 3:
510         VTM = VectorTypeModifier::LFixedLog2LMUL3;
511         break;
512       default:
513         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
514         return None;
515       }
516     } else if (ComplexTT.first == "SFixedLog2LMUL") {
517       int32_t Log2LMUL;
518       if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
519         llvm_unreachable("Invalid SFixedLog2LMUL value!");
520         return None;
521       }
522       switch (Log2LMUL) {
523       case -3:
524         VTM = VectorTypeModifier::SFixedLog2LMULN3;
525         break;
526       case -2:
527         VTM = VectorTypeModifier::SFixedLog2LMULN2;
528         break;
529       case -1:
530         VTM = VectorTypeModifier::SFixedLog2LMULN1;
531         break;
532       case 0:
533         VTM = VectorTypeModifier::SFixedLog2LMUL0;
534         break;
535       case 1:
536         VTM = VectorTypeModifier::SFixedLog2LMUL1;
537         break;
538       case 2:
539         VTM = VectorTypeModifier::SFixedLog2LMUL2;
540         break;
541       case 3:
542         VTM = VectorTypeModifier::SFixedLog2LMUL3;
543         break;
544       default:
545         llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
546         return None;
547       }
548 
549     } else {
550       llvm_unreachable("Illegal complex type transformers!");
551     }
552   }
553   PD.VTM = static_cast<uint8_t>(VTM);
554 
555   // Compute the remain type transformers
556   TypeModifier TM = TypeModifier::NoModifier;
557   for (char I : PrototypeDescriptorStr) {
558     switch (I) {
559     case 'P':
560       if ((TM & TypeModifier::Const) == TypeModifier::Const)
561         llvm_unreachable("'P' transformer cannot be used after 'C'");
562       if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
563         llvm_unreachable("'P' transformer cannot be used twice");
564       TM |= TypeModifier::Pointer;
565       break;
566     case 'C':
567       TM |= TypeModifier::Const;
568       break;
569     case 'K':
570       TM |= TypeModifier::Immediate;
571       break;
572     case 'U':
573       TM |= TypeModifier::UnsignedInteger;
574       break;
575     case 'I':
576       TM |= TypeModifier::SignedInteger;
577       break;
578     case 'F':
579       TM |= TypeModifier::Float;
580       break;
581     case 'S':
582       TM |= TypeModifier::LMUL1;
583       break;
584     default:
585       llvm_unreachable("Illegal non-primitive type transformer!");
586     }
587   }
588   PD.TM = static_cast<uint8_t>(TM);
589 
590   return PD;
591 }
592 
/// Apply a prototype descriptor to the type set up by applyBasicType().
///
/// The descriptor is applied in three stages, in this order:
///  1. The base type modifier (PT) chooses scalar vs. vector (setting Scale),
///     or replaces the type with a fixed scalar kind (void, size_t, ...).
///  2. The vector type modifier (VTM) adjusts SEW, LMUL and/or the scalar
///     kind: widening, mask, EEW, fixed SEW, fixed LMUL.
///  3. Each set TypeModifier bit is applied in increasing bit order.
/// Illegal combinations are reported by setting ScalarType to Invalid, which
/// verifyType() rejects.
void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
  // Handle primitive type transformer
  switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
  case BaseTypeModifier::Scalar:
    // A Scale of 0 is used for scalar types.
    Scale = 0;
    break;
  case BaseTypeModifier::Vector:
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  // The fixed scalar kinds below only replace ScalarType; Scale is not
  // touched here.
  case BaseTypeModifier::Void:
    ScalarType = ScalarTypeKind::Void;
    break;
  case BaseTypeModifier::SizeT:
    ScalarType = ScalarTypeKind::Size_t;
    break;
  case BaseTypeModifier::Ptrdiff:
    ScalarType = ScalarTypeKind::Ptrdiff_t;
    break;
  case BaseTypeModifier::UnsignedLong:
    ScalarType = ScalarTypeKind::UnsignedLong;
    break;
  case BaseTypeModifier::SignedLong:
    ScalarType = ScalarTypeKind::SignedLong;
    break;
  case BaseTypeModifier::Invalid:
    // Nothing further can be applied to an invalid base type.
    ScalarType = ScalarTypeKind::Invalid;
    return;
  }

  switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
  // Widening scales SEW and LMUL by the same factor, then recomputes Scale.
  case VectorTypeModifier::Widening2XVector:
    ElementBitwidth *= 2;
    LMUL.MulLog2LMUL(1);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::Widening4XVector:
    ElementBitwidth *= 4;
    LMUL.MulLog2LMUL(2);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::Widening8XVector:
    ElementBitwidth *= 8;
    LMUL.MulLog2LMUL(3);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::MaskVector:
    // Scale must come from the original SEW, so it is computed before
    // ElementBitwidth is overwritten with the 1-bit mask width.
    ScalarType = ScalarTypeKind::Boolean;
    Scale = LMUL.getScale(ElementBitwidth);
    ElementBitwidth = 1;
    break;
  case VectorTypeModifier::Log2EEW3:
    applyLog2EEW(3);
    break;
  case VectorTypeModifier::Log2EEW4:
    applyLog2EEW(4);
    break;
  case VectorTypeModifier::Log2EEW5:
    applyLog2EEW(5);
    break;
  case VectorTypeModifier::Log2EEW6:
    applyLog2EEW(6);
    break;
  case VectorTypeModifier::FixedSEW8:
    applyFixedSEW(8);
    break;
  case VectorTypeModifier::FixedSEW16:
    applyFixedSEW(16);
    break;
  case VectorTypeModifier::FixedSEW32:
    applyFixedSEW(32);
    break;
  case VectorTypeModifier::FixedSEW64:
    applyFixedSEW(64);
    break;
  // "L" variants require the fixed LMUL to be >= the current LMUL.
  case VectorTypeModifier::LFixedLog2LMULN3:
    applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMULN2:
    applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMULN1:
    applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL0:
    applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL1:
    applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL2:
    applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL3:
    applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
    break;
  // "S" variants require the fixed LMUL to be <= the current LMUL.
  case VectorTypeModifier::SFixedLog2LMULN3:
    applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMULN2:
    applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMULN1:
    applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL0:
    applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL1:
    applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL2:
    applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL3:
    applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::NoModifier:
    break;
  }

  // Walk every TypeModifier bit from bit 0 up to MaxOffset and apply the ones
  // present in Transformer.TM. Later bits may override earlier ones (e.g. a
  // kind bit overrides the kind set above).
  for (unsigned TypeModifierMaskShift = 0;
       TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
       ++TypeModifierMaskShift) {
    unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
    if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
        TypeModifierMask)
      continue;
    switch (static_cast<TypeModifier>(TypeModifierMask)) {
    case TypeModifier::Pointer:
      IsPointer = true;
      break;
    case TypeModifier::Const:
      IsConstant = true;
      break;
    case TypeModifier::Immediate:
      // Immediates are implicitly constant as well.
      IsImmediate = true;
      IsConstant = true;
      break;
    case TypeModifier::UnsignedInteger:
      ScalarType = ScalarTypeKind::UnsignedInteger;
      break;
    case TypeModifier::SignedInteger:
      ScalarType = ScalarTypeKind::SignedInteger;
      break;
    case TypeModifier::Float:
      ScalarType = ScalarTypeKind::Float;
      break;
    case TypeModifier::LMUL1:
      LMUL = LMULType(0);
      // Update ElementBitwidth need to update Scale too.
      Scale = LMUL.getScale(ElementBitwidth);
      break;
    default:
      llvm_unreachable("Unknown type modifier mask!");
    }
  }
}
750 
751 void RVVType::applyLog2EEW(unsigned Log2EEW) {
752   // update new elmul = (eew/sew) * lmul
753   LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
754   // update new eew
755   ElementBitwidth = 1 << Log2EEW;
756   ScalarType = ScalarTypeKind::SignedInteger;
757   Scale = LMUL.getScale(ElementBitwidth);
758 }
759 
760 void RVVType::applyFixedSEW(unsigned NewSEW) {
761   // Set invalid type if src and dst SEW are same.
762   if (ElementBitwidth == NewSEW) {
763     ScalarType = ScalarTypeKind::Invalid;
764     return;
765   }
766   // Update new SEW
767   ElementBitwidth = NewSEW;
768   Scale = LMUL.getScale(ElementBitwidth);
769 }
770 
771 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
772   switch (Type) {
773   case FixedLMULType::LargerThan:
774     if (Log2LMUL < LMUL.Log2LMUL) {
775       ScalarType = ScalarTypeKind::Invalid;
776       return;
777     }
778     break;
779   case FixedLMULType::SmallerThan:
780     if (Log2LMUL > LMUL.Log2LMUL) {
781       ScalarType = ScalarTypeKind::Invalid;
782       return;
783     }
784     break;
785   }
786 
787   // Update new LMUL
788   LMUL = LMULType(Log2LMUL);
789   Scale = LMUL.getScale(ElementBitwidth);
790 }
791 
792 Optional<RVVTypes>
793 RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
794                       ArrayRef<PrototypeDescriptor> Prototype) {
795   // LMUL x NF must be less than or equal to 8.
796   if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
797     return llvm::None;
798 
799   RVVTypes Types;
800   for (const PrototypeDescriptor &Proto : Prototype) {
801     auto T = computeType(BT, Log2LMUL, Proto);
802     if (!T.hasValue())
803       return llvm::None;
804     // Record legal type index
805     Types.push_back(T.getValue());
806   }
807   return Types;
808 }
809 
810 // Compute the hash value of RVVType, used for cache the result of computeType.
811 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
812                                         PrototypeDescriptor Proto) {
813   // Layout of hash value:
814   // 0               8    16          24        32          40
815   // | Log2LMUL + 3  | BT  | Proto.PT | Proto.TM | Proto.VTM |
816   assert(Log2LMUL >= -3 && Log2LMUL <= 3);
817   return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
818          ((uint64_t)(Proto.PT & 0xff) << 16) |
819          ((uint64_t)(Proto.TM & 0xff) << 24) |
820          ((uint64_t)(Proto.VTM & 0xff) << 32);
821 }
822 
823 Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
824                                           PrototypeDescriptor Proto) {
825   uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
826   // Search first
827   auto It = LegalTypes.find(Idx);
828   if (It != LegalTypes.end())
829     return &(It->second);
830 
831   if (IllegalTypes.count(Idx))
832     return llvm::None;
833 
834   // Compute type and record the result.
835   RVVType T(BT, Log2LMUL, Proto);
836   if (T.isValid()) {
837     // Record legal type index and value.
838     LegalTypes.insert({Idx, T});
839     return &(LegalTypes[Idx]);
840   }
841   // Record illegal type index.
842   IllegalTypes.insert(Idx);
843   return llvm::None;
844 }
845 
846 //===----------------------------------------------------------------------===//
847 // RVVIntrinsic implementation
848 //===----------------------------------------------------------------------===//
/// Describe one RVV intrinsic.
///
/// Naming:
///  - BuiltinName is NewName.
///  - Name is NewName plus "_<Suffix>" when Suffix is non-empty.
///  - Both BuiltinName and Name get "_m" for the masked form.
///  - OverloadedName is NewOverloadedName, or the part of NewName before the
///    first '_' when none is given, plus "_<OverloadedSuffix>" when non-empty.
///
/// OutInTypes[0] is the return type and the remaining entries are parameter
/// types, so OutInTypes must be non-empty.
RVVIntrinsic::RVVIntrinsic(
    StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
    StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
    bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
    bool HasUnMaskedOverloaded, bool HasBuiltinAlias, StringRef ManualCodegen,
    const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
    const std::vector<StringRef> &RequiredFeatures, unsigned NF)
    : IRName(IRName), IsMasked(IsMasked), HasVL(HasVL), Scheme(Scheme),
      HasUnMaskedOverloaded(HasUnMaskedOverloaded),
      HasBuiltinAlias(HasBuiltinAlias), ManualCodegen(ManualCodegen.str()),
      NF(NF) {

  // Init BuiltinName, Name and OverloadedName
  BuiltinName = NewName.str();
  Name = BuiltinName;
  if (NewOverloadedName.empty())
    OverloadedName = NewName.split("_").first.str();
  else
    OverloadedName = NewOverloadedName.str();
  if (!Suffix.empty())
    Name += "_" + Suffix.str();
  if (!OverloadedSuffix.empty())
    OverloadedName += "_" + OverloadedSuffix.str();
  if (IsMasked) {
    BuiltinName += "_m";
    Name += "_m";
  }

  // Init RISC-V extensions
  // Collect the predefined macros this intrinsic depends on, based on the
  // element types appearing in its signature.
  for (const auto &T : OutInTypes) {
    if (T->isFloatVector(16) || T->isFloat(16))
      RISCVPredefinedMacros |= RISCVPredefinedMacro::Zvfh;
    if (T->isFloatVector(32))
      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp32;
    if (T->isFloatVector(64))
      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELenFp64;
    if (T->isVector(64))
      RISCVPredefinedMacros |= RISCVPredefinedMacro::VectorMaxELen64;
  }
  // Must run after the loop above: the FullMultiply check reads the
  // VectorMaxELen64 bit set there.
  for (auto Feature : RequiredFeatures) {
    if (Feature == "RV64")
      RISCVPredefinedMacros |= RISCVPredefinedMacro::RV64;
    // Note: Full multiply instruction (mulh, mulhu, mulhsu, smul) for EEW=64
    // require V.
    if (Feature == "FullMultiply" &&
        (RISCVPredefinedMacros & RISCVPredefinedMacro::VectorMaxELen64))
      RISCVPredefinedMacros |= RISCVPredefinedMacro::V;
  }

  // Init OutputType and InputTypes
  OutputType = OutInTypes[0];
  InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());

  // IntrinsicTypes is unmasked TA version index. Need to update it
  // if there is merge operand (It is always in first operand).
  // Non-negative entries are operand indices and are shifted past the NF
  // leading merge operands. Negative entries are left untouched; they
  // presumably denote the return type (TODO confirm against the emitter).
  IntrinsicTypes = NewIntrinsicTypes;
  if ((IsMasked && HasMaskedOffOperand) ||
      (!IsMasked && hasPassthruOperand())) {
    for (auto &I : IntrinsicTypes) {
      if (I >= 0)
        I += NF;
    }
  }
}
913 
914 std::string RVVIntrinsic::getBuiltinTypeStr() const {
915   std::string S;
916   S += OutputType->getBuiltinStr();
917   for (const auto &T : InputTypes) {
918     S += T->getBuiltinStr();
919   }
920   return S;
921 }
922 
923 std::string RVVIntrinsic::getSuffixStr(
924     BasicType Type, int Log2LMUL,
925     llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
926   SmallVector<std::string> SuffixStrs;
927   for (auto PD : PrototypeDescriptors) {
928     auto T = RVVType::computeType(Type, Log2LMUL, PD);
929     SuffixStrs.push_back(T.getValue()->getShortStr());
930   }
931   return join(SuffixStrs, "_");
932 }
933 
934 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
935   SmallVector<PrototypeDescriptor> PrototypeDescriptors;
936   const StringRef Primaries("evwqom0ztul");
937   while (!Prototypes.empty()) {
938     size_t Idx = 0;
939     // Skip over complex prototype because it could contain primitive type
940     // character.
941     if (Prototypes[0] == '(')
942       Idx = Prototypes.find_first_of(')');
943     Idx = Prototypes.find_first_of(Primaries, Idx);
944     assert(Idx != StringRef::npos);
945     auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
946         Prototypes.slice(0, Idx + 1));
947     if (!PD)
948       llvm_unreachable("Error during parsing prototype.");
949     PrototypeDescriptors.push_back(*PD);
950     Prototypes = Prototypes.drop_front(Idx + 1);
951   }
952   return PrototypeDescriptors;
953 }
954 
955 } // end namespace RISCV
956 } // end namespace clang
957