1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Attribute wrapper to simplify using TableGen Record defining a MLIR 10 // Attribute. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "llvm/TableGen/Record.h" 17 18 using namespace mlir; 19 using namespace mlir::tblgen; 20 21 using llvm::DefInit; 22 using llvm::Init; 23 using llvm::Record; 24 using llvm::StringInit; 25 26 // Returns the initializer's value as string if the given TableGen initializer 27 // is a code or string initializer. Returns the empty StringRef otherwise. getValueAsString(const Init * init)28static StringRef getValueAsString(const Init *init) { 29 if (const auto *str = dyn_cast<StringInit>(init)) 30 return str->getValue().trim(); 31 return {}; 32 } 33 isSubClassOf(StringRef className) const34bool AttrConstraint::isSubClassOf(StringRef className) const { 35 return def->isSubClassOf(className); 36 } 37 Attribute(const Record * record)38Attribute::Attribute(const Record *record) : AttrConstraint(record) { 39 assert(record->isSubClassOf("Attr") && 40 "must be subclass of TableGen 'Attr' class"); 41 } 42 Attribute(const DefInit * init)43Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} 44 isDerivedAttr() const45bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); } 46 isTypeAttr() const47bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); } 48 isSymbolRefAttr() const49bool Attribute::isSymbolRefAttr() const { 50 StringRef defName = def->getName(); 51 if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr") 52 return true; 53 return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr"); 54 } 55 isEnumAttr() const56bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } 57 getStorageType() const58StringRef Attribute::getStorageType() const { 59 const auto *init = def->getValueInit("storageType"); 60 auto type = getValueAsString(init); 61 if (type.empty()) 62 return "::mlir::Attribute"; 63 return type; 64 } 65 getReturnType() const66StringRef Attribute::getReturnType() const { 67 const auto *init = def->getValueInit("returnType"); 68 return getValueAsString(init); 69 } 70 71 // Return the type constraint corresponding to the type of this attribute, or 72 // None if this is not a TypedAttr. getValueType() const73llvm::Optional<Type> Attribute::getValueType() const { 74 if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType"))) 75 return Type(defInit->getDef()); 76 return llvm::None; 77 } 78 getConvertFromStorageCall() const79StringRef Attribute::getConvertFromStorageCall() const { 80 const auto *init = def->getValueInit("convertFromStorage"); 81 return getValueAsString(init); 82 } 83 isConstBuildable() const84bool Attribute::isConstBuildable() const { 85 const auto *init = def->getValueInit("constBuilderCall"); 86 return !getValueAsString(init).empty(); 87 } 88 getConstBuilderTemplate() const89StringRef Attribute::getConstBuilderTemplate() const { 90 const auto *init = def->getValueInit("constBuilderCall"); 91 return getValueAsString(init); 92 } 93 getBaseAttr() const94Attribute Attribute::getBaseAttr() const { 95 if (const auto *defInit = 96 llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) { 97 return Attribute(defInit).getBaseAttr(); 98 } 99 return *this; 100 } 101 hasDefaultValue() const102bool Attribute::hasDefaultValue() const { 103 const auto *init = def->getValueInit("defaultValue"); 104 return !getValueAsString(init).empty(); 105 } 106 getDefaultValue() const107StringRef Attribute::getDefaultValue() const { 108 const auto *init = def->getValueInit("defaultValue"); 109 return getValueAsString(init); 110 } 111 isOptional() const112bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); } 113 getAttrDefName() const114StringRef Attribute::getAttrDefName() const { 115 if (def->isAnonymous()) { 116 return getBaseAttr().def->getName(); 117 } 118 return def->getName(); 119 } 120 getDerivedCodeBody() const121StringRef Attribute::getDerivedCodeBody() const { 122 assert(isDerivedAttr() && "only derived attribute has 'body' field"); 123 return def->getValueAsString("body"); 124 } 125 getDialect() const126Dialect Attribute::getDialect() const { 127 const llvm::RecordVal *record = def->getValue("dialect"); 128 if (record && record->getValue()) { 129 if (DefInit *init = dyn_cast<DefInit>(record->getValue())) 130 return Dialect(init->getDef()); 131 } 132 return Dialect(nullptr); 133 } 134 ConstantAttr(const DefInit * init)135ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { 136 assert(def->isSubClassOf("ConstantAttr") && 137 "must be subclass of TableGen 'ConstantAttr' class"); 138 } 139 getAttribute() const140Attribute ConstantAttr::getAttribute() const { 141 return Attribute(def->getValueAsDef("attr")); 142 } 143 getConstantValue() const144StringRef ConstantAttr::getConstantValue() const { 145 return def->getValueAsString("value"); 146 } 147 EnumAttrCase(const llvm::Record * record)148EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) { 149 assert(isSubClassOf("EnumAttrCaseInfo") && 150 "must be subclass of TableGen 'EnumAttrInfo' class"); 151 } 152 EnumAttrCase(const llvm::DefInit * init)153EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) 154 : EnumAttrCase(init->getDef()) {} 155 getSymbol() const156StringRef EnumAttrCase::getSymbol() const { 157 return def->getValueAsString("symbol"); 158 } 159 getStr() const160StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } 161 getValue() const162int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } 163 getDef() const164const llvm::Record &EnumAttrCase::getDef() const { return *def; } 165 EnumAttr(const llvm::Record * record)166EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { 167 assert(isSubClassOf("EnumAttrInfo") && 168 "must be subclass of TableGen 'EnumAttr' class"); 169 } 170 EnumAttr(const llvm::Record & record)171EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} 172 EnumAttr(const llvm::DefInit * init)173EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} 174 classof(const Attribute * attr)175bool EnumAttr::classof(const Attribute *attr) { 176 return attr->isSubClassOf("EnumAttrInfo"); 177 } 178 isBitEnum() const179bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } 180 getEnumClassName() const181StringRef EnumAttr::getEnumClassName() const { 182 return def->getValueAsString("className"); 183 } 184 getCppNamespace() const185StringRef EnumAttr::getCppNamespace() const { 186 return def->getValueAsString("cppNamespace"); 187 } 188 getUnderlyingType() const189StringRef EnumAttr::getUnderlyingType() const { 190 return def->getValueAsString("underlyingType"); 191 } 192 getUnderlyingToSymbolFnName() const193StringRef EnumAttr::getUnderlyingToSymbolFnName() const { 194 return def->getValueAsString("underlyingToSymbolFnName"); 195 } 196 getStringToSymbolFnName() const197StringRef EnumAttr::getStringToSymbolFnName() const { 198 return def->getValueAsString("stringToSymbolFnName"); 199 } 200 getSymbolToStringFnName() const201StringRef EnumAttr::getSymbolToStringFnName() const { 202 return def->getValueAsString("symbolToStringFnName"); 203 } 204 getSymbolToStringFnRetType() const205StringRef EnumAttr::getSymbolToStringFnRetType() const { 206 return def->getValueAsString("symbolToStringFnRetType"); 207 } 208 getMaxEnumValFnName() const209StringRef EnumAttr::getMaxEnumValFnName() const { 210 return def->getValueAsString("maxEnumValFnName"); 211 } 212 getAllCases() const213std::vector<EnumAttrCase> EnumAttr::getAllCases() const { 214 const auto *inits = def->getValueAsListInit("enumerants"); 215 216 std::vector<EnumAttrCase> cases; 217 cases.reserve(inits->size()); 218 219 for (const llvm::Init *init : *inits) { 220 cases.emplace_back(cast<llvm::DefInit>(init)); 221 } 222 223 return cases; 224 } 225 genSpecializedAttr() const226bool EnumAttr::genSpecializedAttr() const { 227 return def->getValueAsBit("genSpecializedAttr"); 228 } 229 getBaseAttrClass() const230llvm::Record *EnumAttr::getBaseAttrClass() const { 231 return def->getValueAsDef("baseAttrClass"); 232 } 233 getSpecializedAttrClassName() const234StringRef EnumAttr::getSpecializedAttrClassName() const { 235 return def->getValueAsString("specializedAttrClassName"); 236 } 237 printBitEnumPrimaryGroups() const238bool EnumAttr::printBitEnumPrimaryGroups() const { 239 return def->getValueAsBit("printBitEnumPrimaryGroups"); 240 } 241 242 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; 243