1 //===- Operator.cpp - Operator class --------------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Operator.h" 23 #include "mlir/TableGen/OpTrait.h" 24 #include "mlir/TableGen/Predicate.h" 25 #include "mlir/TableGen/Type.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/TableGen/Error.h" 28 #include "llvm/TableGen/Record.h" 29 30 #define DEBUG_TYPE "mlir-tblgen-operator" 31 32 using namespace mlir; 33 34 using llvm::DagInit; 35 using llvm::DefInit; 36 using llvm::Record; 37 38 tblgen::Operator::Operator(const llvm::Record &def) 39 : dialect(def.getValueAsDef("opDialect")), def(def) { 40 // The first `_` in the op's TableGen def name is treated as separating the 41 // dialect prefix and the op class name. The dialect prefix will be ignored if 42 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered 43 // as part of the class name. 44 StringRef prefix; 45 std::tie(prefix, cppClassName) = def.getName().split('_'); 46 if (prefix.empty()) { 47 // Class name with a leading underscore and without dialect prefix 48 cppClassName = def.getName(); 49 } else if (cppClassName.empty()) { 50 // Class name without dialect prefix 51 cppClassName = prefix; 52 } 53 54 populateOpStructure(); 55 } 56 57 std::string tblgen::Operator::getOperationName() const { 58 auto prefix = dialect.getName(); 59 auto opName = def.getValueAsString("opName"); 60 if (prefix.empty()) 61 return opName; 62 return llvm::formatv("{0}.{1}", prefix, opName); 63 } 64 65 StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } 66 67 StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } 68 69 std::string tblgen::Operator::getQualCppClassName() const { 70 auto prefix = dialect.getCppNamespace(); 71 if (prefix.empty()) 72 return cppClassName; 73 return llvm::formatv("{0}::{1}", prefix, cppClassName); 74 } 75 76 int tblgen::Operator::getNumResults() const { 77 DagInit *results = def.getValueAsDag("results"); 78 return results->getNumArgs(); 79 } 80 81 StringRef tblgen::Operator::getExtraClassDeclaration() const { 82 constexpr auto attr = "extraClassDeclaration"; 83 if (def.isValueUnset(attr)) 84 return {}; 85 return def.getValueAsString(attr); 86 } 87 88 const llvm::Record &tblgen::Operator::getDef() const { return def; } 89 90 bool tblgen::Operator::isVariadic() const { 91 return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0; 92 } 93 94 bool tblgen::Operator::skipDefaultBuilders() const { 95 return def.getValueAsBit("skipDefaultBuilders"); 96 } 97 98 auto tblgen::Operator::result_begin() -> value_iterator { 99 return results.begin(); 100 } 101 102 auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } 103 104 auto tblgen::Operator::getResults() -> value_range { 105 return {result_begin(), result_end()}; 106 } 107 108 tblgen::TypeConstraint 109 tblgen::Operator::getResultTypeConstraint(int index) const { 110 DagInit *results = def.getValueAsDag("results"); 111 return TypeConstraint(cast<DefInit>(results->getArg(index))); 112 } 113 114 StringRef tblgen::Operator::getResultName(int index) const { 115 DagInit *results = def.getValueAsDag("results"); 116 return results->getArgNameStr(index); 117 } 118 119 unsigned tblgen::Operator::getNumVariadicResults() const { 120 return std::count_if( 121 results.begin(), results.end(), 122 [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); 123 } 124 125 unsigned tblgen::Operator::getNumVariadicOperands() const { 126 return std::count_if( 127 operands.begin(), operands.end(), 128 [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); 129 } 130 131 tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const { 132 return arguments.begin(); 133 } 134 135 tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const { 136 return arguments.end(); 137 } 138 139 tblgen::Operator::arg_range tblgen::Operator::getArgs() const { 140 return {arg_begin(), arg_end()}; 141 } 142 143 StringRef tblgen::Operator::getArgName(int index) const { 144 DagInit *argumentValues = def.getValueAsDag("arguments"); 145 return argumentValues->getArgName(index)->getValue(); 146 } 147 148 const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const { 149 for (const auto &t : traits) { 150 if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) { 151 if (opTrait->getTrait() == trait) 152 return opTrait; 153 } else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) { 154 if (opTrait->getTrait() == trait) 155 return opTrait; 156 } else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&t)) { 157 if (opTrait->getTrait() == trait) 158 return opTrait; 159 } 160 } 161 return nullptr; 162 } 163 164 unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } 165 166 const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const { 167 return regions[index]; 168 } 169 170 auto tblgen::Operator::trait_begin() const -> const_trait_iterator { 171 return traits.begin(); 172 } 173 auto tblgen::Operator::trait_end() const -> const_trait_iterator { 174 return traits.end(); 175 } 176 auto tblgen::Operator::getTraits() const 177 -> llvm::iterator_range<const_trait_iterator> { 178 return {trait_begin(), trait_end()}; 179 } 180 181 auto tblgen::Operator::attribute_begin() const -> attribute_iterator { 182 return attributes.begin(); 183 } 184 auto tblgen::Operator::attribute_end() const -> attribute_iterator { 185 return attributes.end(); 186 } 187 auto tblgen::Operator::getAttributes() const 188 -> llvm::iterator_range<attribute_iterator> { 189 return {attribute_begin(), attribute_end()}; 190 } 191 192 auto tblgen::Operator::operand_begin() -> value_iterator { 193 return operands.begin(); 194 } 195 auto tblgen::Operator::operand_end() -> value_iterator { 196 return operands.end(); 197 } 198 auto tblgen::Operator::getOperands() -> value_range { 199 return {operand_begin(), operand_end()}; 200 } 201 202 auto tblgen::Operator::getArg(int index) const -> Argument { 203 return arguments[index]; 204 } 205 206 void tblgen::Operator::populateOpStructure() { 207 auto &recordKeeper = def.getRecords(); 208 auto typeConstraintClass = recordKeeper.getClass("TypeConstraint"); 209 auto attrClass = recordKeeper.getClass("Attr"); 210 auto derivedAttrClass = recordKeeper.getClass("DerivedAttr"); 211 numNativeAttributes = 0; 212 213 DagInit *argumentValues = def.getValueAsDag("arguments"); 214 unsigned numArgs = argumentValues->getNumArgs(); 215 216 // Handle operands and native attributes. 217 for (unsigned i = 0; i != numArgs; ++i) { 218 auto arg = argumentValues->getArg(i); 219 auto givenName = argumentValues->getArgNameStr(i); 220 auto argDefInit = dyn_cast<DefInit>(arg); 221 if (!argDefInit) 222 PrintFatalError(def.getLoc(), 223 Twine("undefined type for argument #") + Twine(i)); 224 Record *argDef = argDefInit->getDef(); 225 226 if (argDef->isSubClassOf(typeConstraintClass)) { 227 operands.push_back( 228 NamedTypeConstraint{givenName, TypeConstraint(argDefInit)}); 229 } else if (argDef->isSubClassOf(attrClass)) { 230 if (givenName.empty()) 231 PrintFatalError(argDef->getLoc(), "attributes must be named"); 232 if (argDef->isSubClassOf(derivedAttrClass)) 233 PrintFatalError(argDef->getLoc(), 234 "derived attributes not allowed in argument list"); 235 attributes.push_back({givenName, Attribute(argDef)}); 236 ++numNativeAttributes; 237 } else { 238 PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " 239 "from TypeConstraint or Attr are allowed"); 240 } 241 } 242 243 // Handle derived attributes. 244 for (const auto &val : def.getValues()) { 245 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) { 246 if (!record->isSubClassOf(attrClass)) 247 continue; 248 if (!record->isSubClassOf(derivedAttrClass)) 249 PrintFatalError(def.getLoc(), 250 "unexpected Attr where only DerivedAttr is allowed"); 251 252 if (record->getClasses().size() != 1) { 253 PrintFatalError( 254 def.getLoc(), 255 "unsupported attribute modelling, only single class expected"); 256 } 257 attributes.push_back( 258 {cast<llvm::StringInit>(val.getNameInit())->getValue(), 259 Attribute(cast<DefInit>(val.getValue()))}); 260 } 261 } 262 263 // Populate `arguments`. This must happen after we've finalized `operands` and 264 // `attributes` because we will put their elements' pointers in `arguments`. 265 // SmallVector may perform re-allocation under the hood when adding new 266 // elements. 267 int operandIndex = 0, attrIndex = 0; 268 for (unsigned i = 0; i != numArgs; ++i) { 269 Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef(); 270 271 if (argDef->isSubClassOf(typeConstraintClass)) { 272 arguments.emplace_back(&operands[operandIndex++]); 273 } else { 274 assert(argDef->isSubClassOf(attrClass)); 275 arguments.emplace_back(&attributes[attrIndex++]); 276 } 277 } 278 279 auto *resultsDag = def.getValueAsDag("results"); 280 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); 281 if (!outsOp || outsOp->getDef()->getName() != "outs") { 282 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); 283 } 284 285 // Handle results. 286 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { 287 auto name = resultsDag->getArgNameStr(i); 288 auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i)); 289 if (!resultDef) { 290 PrintFatalError(def.getLoc(), 291 Twine("undefined type for result #") + Twine(i)); 292 } 293 results.push_back({name, TypeConstraint(resultDef)}); 294 } 295 296 auto traitListInit = def.getValueAsListInit("traits"); 297 if (!traitListInit) 298 return; 299 traits.reserve(traitListInit->size()); 300 for (auto traitInit : *traitListInit) 301 traits.push_back(OpTrait::create(traitInit)); 302 303 // Handle regions 304 auto *regionsDag = def.getValueAsDag("regions"); 305 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator()); 306 if (!regionsOp || regionsOp->getDef()->getName() != "region") { 307 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); 308 } 309 310 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { 311 auto name = regionsDag->getArgNameStr(i); 312 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i)); 313 if (!regionInit) { 314 PrintFatalError(def.getLoc(), 315 Twine("undefined kind for region #") + Twine(i)); 316 } 317 regions.push_back({name, Region(regionInit->getDef())}); 318 } 319 320 LLVM_DEBUG(print(llvm::dbgs())); 321 } 322 323 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); } 324 325 bool tblgen::Operator::hasDescription() const { 326 return def.getValue("description") != nullptr; 327 } 328 329 StringRef tblgen::Operator::getDescription() const { 330 return def.getValueAsString("description"); 331 } 332 333 bool tblgen::Operator::hasSummary() const { 334 return def.getValue("summary") != nullptr; 335 } 336 337 StringRef tblgen::Operator::getSummary() const { 338 return def.getValueAsString("summary"); 339 } 340 341 void tblgen::Operator::print(llvm::raw_ostream &os) const { 342 os << "op '" << getOperationName() << "'\n"; 343 for (Argument arg : arguments) { 344 if (auto *attr = arg.dyn_cast<NamedAttribute *>()) 345 os << "[attribute] " << attr->name << '\n'; 346 else 347 os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n'; 348 } 349 } 350