1 //===- Type.cpp - Type class ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Type wrapper to simplify using TableGen Record defining a MLIR Type.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Type.h"
14 #include "mlir/TableGen/Dialect.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/TableGen/Record.h"
18 
19 using namespace mlir;
20 using namespace mlir::tblgen;
21 
22 TypeConstraint::TypeConstraint(const llvm::Record *record)
23     : Constraint(Constraint::CK_Type, record) {
24   assert(def->isSubClassOf("TypeConstraint") &&
25          "must be subclass of TableGen 'TypeConstraint' class");
26 }
27 
28 TypeConstraint::TypeConstraint(const llvm::DefInit *init)
29     : TypeConstraint(init->getDef()) {}
30 
31 bool TypeConstraint::isOptional() const {
32   return def->isSubClassOf("Optional");
33 }
34 
35 bool TypeConstraint::isVariadic() const {
36   return def->isSubClassOf("Variadic");
37 }
38 
39 // Returns the builder call for this constraint if this is a buildable type,
40 // returns None otherwise.
41 Optional<StringRef> TypeConstraint::getBuilderCall() const {
42   const llvm::Record *baseType = def;
43   if (isVariableLength())
44     baseType = baseType->getValueAsDef("baseType");
45 
46   // Check to see if this type constraint has a builder call.
47   const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
48   if (!builderCall || !builderCall->getValue())
49     return llvm::None;
50   return TypeSwitch<llvm::Init *, Optional<StringRef>>(builderCall->getValue())
51       .Case<llvm::StringInit>([&](auto *init) {
52         StringRef value = init->getValue();
53         return value.empty() ? Optional<StringRef>() : value;
54       })
55       .Default([](auto *) { return llvm::None; });
56 }
57 
58 // Return the C++ class name for this type (which may just be ::mlir::Type).
59 std::string TypeConstraint::getCPPClassName() const {
60   StringRef className = def->getValueAsString("cppClassName");
61 
62   // If the class name is already namespace resolved, use it.
63   if (className.contains("::"))
64     return className.str();
65 
66   // Otherwise, check to see if there is a namespace from a dialect to prepend.
67   if (const llvm::RecordVal *value = def->getValue("dialect")) {
68     Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
69     return (dialect.getCppNamespace() + "::" + className).str();
70   }
71   return className.str();
72 }
73 
74 Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
75 
76 StringRef Type::getDescription() const {
77   return def->getValueAsString("description");
78 }
79 
80 Dialect Type::getDialect() const {
81   return Dialect(def->getValueAsDef("dialect"));
82 }
83