1 //==- SemaRISCVVectorLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
//  This file implements name lookup for RISC-V vector intrinsics.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/Decl.h"
15 #include "clang/Basic/Builtins.h"
16 #include "clang/Basic/TargetInfo.h"
17 #include "clang/Lex/Preprocessor.h"
18 #include "clang/Sema/Lookup.h"
19 #include "clang/Sema/RISCVIntrinsicManager.h"
20 #include "clang/Sema/Sema.h"
21 #include "clang/Support/RISCVVIntrinsicUtils.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include <string>
24 #include <vector>
25 
26 using namespace llvm;
27 using namespace clang;
28 using namespace clang::RISCV;
29 
30 namespace {
31 
32 // Function definition of a RVV intrinsic.
33 struct RVVIntrinsicDef {
34   /// Full function name with suffix, e.g. vadd_vv_i32m1.
35   std::string Name;
36 
37   /// Overloaded function name, e.g. vadd.
38   std::string OverloadName;
39 
40   /// Mapping to which clang built-in function, e.g. __builtin_rvv_vadd.
41   std::string BuiltinName;
42 
43   /// Function signature, first element is return type.
44   RVVTypes Signature;
45 };
46 
47 struct RVVOverloadIntrinsicDef {
48   // Indexes of RISCVIntrinsicManagerImpl::IntrinsicList.
49   SmallVector<size_t, 8> Indexes;
50 };
51 
52 } // namespace
53 
// Pool of prototype descriptors shared by all intrinsic records. Records refer
// to contiguous runs of this table by (index, length) pairs; see
// ProtoSeq2ArrayRef. The entries are pulled in from
// riscv_vector_builtin_sema.inc when DECL_SIGNATURE_TABLE is defined
// (presumably emitted by RISCVVEmitter.cpp -- confirm against the emitter).
static const PrototypeDescriptor RVVSignatureTable[] = {
#define DECL_SIGNATURE_TABLE
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_SIGNATURE_TABLE
};
59 
// One record per intrinsic family. InitIntrinsicList expands each record into
// concrete intrinsics over its type range and LMUL mask. The entries are pulled
// in from riscv_vector_builtin_sema.inc when DECL_INTRINSIC_RECORDS is defined.
static const RVVIntrinsicRecord RVVIntrinsicRecords[] = {
#define DECL_INTRINSIC_RECORDS
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_INTRINSIC_RECORDS
};
65 
66 // Get subsequence of signature table.
67 static ArrayRef<PrototypeDescriptor> ProtoSeq2ArrayRef(uint16_t Index,
68                                                        uint8_t Length) {
69   return makeArrayRef(&RVVSignatureTable[Index], Length);
70 }
71 
72 static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
73   QualType QT;
74   switch (Type->getScalarType()) {
75   case ScalarTypeKind::Void:
76     QT = Context.VoidTy;
77     break;
78   case ScalarTypeKind::Size_t:
79     QT = Context.getSizeType();
80     break;
81   case ScalarTypeKind::Ptrdiff_t:
82     QT = Context.getPointerDiffType();
83     break;
84   case ScalarTypeKind::UnsignedLong:
85     QT = Context.UnsignedLongTy;
86     break;
87   case ScalarTypeKind::SignedLong:
88     QT = Context.LongTy;
89     break;
90   case ScalarTypeKind::Boolean:
91     QT = Context.BoolTy;
92     break;
93   case ScalarTypeKind::SignedInteger:
94     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true);
95     break;
96   case ScalarTypeKind::UnsignedInteger:
97     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
98     break;
99   case ScalarTypeKind::Float:
100     switch (Type->getElementBitwidth()) {
101     case 64:
102       QT = Context.DoubleTy;
103       break;
104     case 32:
105       QT = Context.FloatTy;
106       break;
107     case 16:
108       QT = Context.Float16Ty;
109       break;
110     default:
111       llvm_unreachable("Unsupported floating point width.");
112     }
113     break;
114   case Invalid:
115     llvm_unreachable("Unhandled type.");
116   }
117   if (Type->isVector())
118     QT = Context.getScalableVectorType(QT, Type->getScale().getValue());
119 
120   if (Type->isConstant())
121     QT = Context.getConstType(QT);
122 
123   // Transform the type to a pointer as the last step, if necessary.
124   if (Type->isPointer())
125     QT = Context.getPointerType(QT);
126 
127   return QT;
128 }
129 
130 namespace {
131 class RISCVIntrinsicManagerImpl : public sema::RISCVIntrinsicManager {
132 private:
133   Sema &S;
134   ASTContext &Context;
135 
136   // List of all RVV intrinsic.
137   std::vector<RVVIntrinsicDef> IntrinsicList;
138   // Mapping function name to index of IntrinsicList.
139   StringMap<size_t> Intrinsics;
140   // Mapping function name to RVVOverloadIntrinsicDef.
141   StringMap<RVVOverloadIntrinsicDef> OverloadIntrinsics;
142 
143   // Create IntrinsicList
144   void InitIntrinsicList();
145 
146   // Create RVVIntrinsicDef.
147   void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr,
148                         StringRef OverloadedSuffixStr, bool IsMask,
149                         RVVTypes &Types);
150 
151   // Create FunctionDecl for a vector intrinsic.
152   void CreateRVVIntrinsicDecl(LookupResult &LR, IdentifierInfo *II,
153                               Preprocessor &PP, unsigned Index,
154                               bool IsOverload);
155 
156 public:
157   RISCVIntrinsicManagerImpl(clang::Sema &S) : S(S), Context(S.Context) {
158     InitIntrinsicList();
159   }
160 
161   // Create RISC-V vector intrinsic and insert into symbol table if found, and
162   // return true, otherwise return false.
163   bool CreateIntrinsicIfFound(LookupResult &LR, IdentifierInfo *II,
164                               Preprocessor &PP) override;
165 };
166 } // namespace
167 
168 void RISCVIntrinsicManagerImpl::InitIntrinsicList() {
169   const TargetInfo &TI = Context.getTargetInfo();
170   bool HasVectorFloat32 = TI.hasFeature("zve32f");
171   bool HasVectorFloat64 = TI.hasFeature("zve64d");
172   bool HasZvfh = TI.hasFeature("experimental-zvfh");
173   bool HasRV64 = TI.hasFeature("64bit");
174   bool HasFullMultiply = TI.hasFeature("v");
175 
176   // Construction of RVVIntrinsicRecords need to sync with createRVVIntrinsics
177   // in RISCVVEmitter.cpp.
178   for (auto &Record : RVVIntrinsicRecords) {
179     // Create Intrinsics for each type and LMUL.
180     BasicType BaseType = BasicType::Unknown;
181     ArrayRef<PrototypeDescriptor> ProtoSeq =
182         ProtoSeq2ArrayRef(Record.PrototypeIndex, Record.PrototypeLength);
183     ArrayRef<PrototypeDescriptor> ProtoMaskSeq = ProtoSeq2ArrayRef(
184         Record.MaskedPrototypeIndex, Record.MaskedPrototypeLength);
185     ArrayRef<PrototypeDescriptor> SuffixProto =
186         ProtoSeq2ArrayRef(Record.SuffixIndex, Record.SuffixLength);
187     ArrayRef<PrototypeDescriptor> OverloadedSuffixProto = ProtoSeq2ArrayRef(
188         Record.OverloadedSuffixIndex, Record.OverloadedSuffixSize);
189     for (unsigned int TypeRangeMaskShift = 0;
190          TypeRangeMaskShift <= static_cast<unsigned int>(BasicType::MaxOffset);
191          ++TypeRangeMaskShift) {
192       unsigned int BaseTypeI = 1 << TypeRangeMaskShift;
193       BaseType = static_cast<BasicType>(BaseTypeI);
194 
195       if ((BaseTypeI & Record.TypeRangeMask) != BaseTypeI)
196         continue;
197 
198       // Check requirement.
199       if (BaseType == BasicType::Float16 && !HasZvfh)
200         continue;
201 
202       if (BaseType == BasicType::Float32 && !HasVectorFloat32)
203         continue;
204 
205       if (BaseType == BasicType::Float64 && !HasVectorFloat64)
206         continue;
207 
208       if (((Record.RequiredExtensions & RVV_REQ_RV64) == RVV_REQ_RV64) &&
209           !HasRV64)
210         continue;
211 
212       if ((BaseType == BasicType::Int64) &&
213           ((Record.RequiredExtensions & RVV_REQ_FullMultiply) ==
214            RVV_REQ_FullMultiply) &&
215           !HasFullMultiply)
216         continue;
217 
218       // Expanded with different LMUL.
219       for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) {
220         if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3))))
221           continue;
222 
223         Optional<RVVTypes> Types =
224             RVVType::computeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq);
225 
226         // Ignored to create new intrinsic if there are any illegal types.
227         if (!Types.hasValue())
228           continue;
229 
230         std::string SuffixStr =
231             RVVIntrinsic::getSuffixStr(BaseType, Log2LMUL, SuffixProto);
232         std::string OverloadedSuffixStr = RVVIntrinsic::getSuffixStr(
233             BaseType, Log2LMUL, OverloadedSuffixProto);
234 
235         // Create non-masked intrinsic.
236         InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, false, *Types);
237 
238         if (Record.MaskedPrototypeLength != 0) {
239           // Create masked intrinsic.
240           Optional<RVVTypes> MaskTypes = RVVType::computeTypes(
241               BaseType, Log2LMUL, Record.NF, ProtoMaskSeq);
242 
243           InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, true,
244                            *MaskTypes);
245         }
246       }
247     }
248   }
249 }
250 
251 // Compute name and signatures for intrinsic with practical types.
252 void RISCVIntrinsicManagerImpl::InitRVVIntrinsic(
253     const RVVIntrinsicRecord &Record, StringRef SuffixStr,
254     StringRef OverloadedSuffixStr, bool IsMask, RVVTypes &Signature) {
255   // Function name, e.g. vadd_vv_i32m1.
256   std::string Name = Record.Name;
257   if (!SuffixStr.empty())
258     Name += "_" + SuffixStr.str();
259 
260   if (IsMask)
261     Name += "_m";
262 
263   // Overloaded function name, e.g. vadd.
264   std::string OverloadedName;
265   if (!Record.OverloadedName)
266     OverloadedName = StringRef(Record.Name).split("_").first.str();
267   else
268     OverloadedName = Record.OverloadedName;
269   if (!OverloadedSuffixStr.empty())
270     OverloadedName += "_" + OverloadedSuffixStr.str();
271 
272   // clang built-in function name, e.g. __builtin_rvv_vadd.
273   std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name);
274   if (IsMask)
275     BuiltinName += "_m";
276 
277   // Put into IntrinsicList.
278   size_t Index = IntrinsicList.size();
279   IntrinsicList.push_back({Name, OverloadedName, BuiltinName, Signature});
280 
281   // Creating mapping to Intrinsics.
282   Intrinsics.insert({Name, Index});
283 
284   // Get the RVVOverloadIntrinsicDef.
285   RVVOverloadIntrinsicDef &OverloadIntrinsicDef =
286       OverloadIntrinsics[OverloadedName];
287 
288   // And added the index.
289   OverloadIntrinsicDef.Indexes.push_back(Index);
290 }
291 
292 void RISCVIntrinsicManagerImpl::CreateRVVIntrinsicDecl(LookupResult &LR,
293                                                        IdentifierInfo *II,
294                                                        Preprocessor &PP,
295                                                        unsigned Index,
296                                                        bool IsOverload) {
297   ASTContext &Context = S.Context;
298   RVVIntrinsicDef &IDef = IntrinsicList[Index];
299   RVVTypes Sigs = IDef.Signature;
300   size_t SigLength = Sigs.size();
301   RVVType *ReturnType = Sigs[0];
302   QualType RetType = RVVType2Qual(Context, ReturnType);
303   SmallVector<QualType, 8> ArgTypes;
304   QualType BuiltinFuncType;
305 
306   // Skip return type, and convert RVVType to QualType for arguments.
307   for (size_t i = 1; i < SigLength; ++i)
308     ArgTypes.push_back(RVVType2Qual(Context, Sigs[i]));
309 
310   FunctionProtoType::ExtProtoInfo PI(
311       Context.getDefaultCallingConvention(false, false, true));
312 
313   PI.Variadic = false;
314 
315   SourceLocation Loc = LR.getNameLoc();
316   BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI);
317   DeclContext *Parent = Context.getTranslationUnitDecl();
318 
319   FunctionDecl *RVVIntrinsicDecl = FunctionDecl::Create(
320       Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr,
321       SC_Extern, S.getCurFPFeatures().isFPConstrained(),
322       /*isInlineSpecified*/ false,
323       /*hasWrittenPrototype*/ true);
324 
325   // Create Decl objects for each parameter, adding them to the
326   // FunctionDecl.
327   const auto *FP = cast<FunctionProtoType>(BuiltinFuncType);
328   SmallVector<ParmVarDecl *, 8> ParmList;
329   for (unsigned IParm = 0, E = FP->getNumParams(); IParm != E; ++IParm) {
330     ParmVarDecl *Parm =
331         ParmVarDecl::Create(Context, RVVIntrinsicDecl, Loc, Loc, nullptr,
332                             FP->getParamType(IParm), nullptr, SC_None, nullptr);
333     Parm->setScopeInfo(0, IParm);
334     ParmList.push_back(Parm);
335   }
336   RVVIntrinsicDecl->setParams(ParmList);
337 
338   // Add function attributes.
339   if (IsOverload)
340     RVVIntrinsicDecl->addAttr(OverloadableAttr::CreateImplicit(Context));
341 
342   // Setup alias to __builtin_rvv_*
343   IdentifierInfo &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName);
344   RVVIntrinsicDecl->addAttr(
345       BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII));
346 
347   // Add to symbol table.
348   LR.addDecl(RVVIntrinsicDecl);
349 }
350 
351 bool RISCVIntrinsicManagerImpl::CreateIntrinsicIfFound(LookupResult &LR,
352                                                        IdentifierInfo *II,
353                                                        Preprocessor &PP) {
354   StringRef Name = II->getName();
355 
356   // Lookup the function name from the overload intrinsics first.
357   auto OvIItr = OverloadIntrinsics.find(Name);
358   if (OvIItr != OverloadIntrinsics.end()) {
359     const RVVOverloadIntrinsicDef &OvIntrinsicDef = OvIItr->second;
360     for (auto Index : OvIntrinsicDef.Indexes)
361       CreateRVVIntrinsicDecl(LR, II, PP, Index,
362                              /*IsOverload*/ true);
363 
364     // If we added overloads, need to resolve the lookup result.
365     LR.resolveKind();
366     return true;
367   }
368 
369   // Lookup the function name from the intrinsics.
370   auto Itr = Intrinsics.find(Name);
371   if (Itr != Intrinsics.end()) {
372     CreateRVVIntrinsicDecl(LR, II, PP, Itr->second,
373                            /*IsOverload*/ false);
374     return true;
375   }
376 
377   // It's not an RVV intrinsics.
378   return false;
379 }
380 
381 namespace clang {
382 std::unique_ptr<clang::sema::RISCVIntrinsicManager>
383 CreateRISCVIntrinsicManager(Sema &S) {
384   return std::make_unique<RISCVIntrinsicManagerImpl>(S);
385 }
386 } // namespace clang
387