1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/ExtensibleDialect.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Support/LogicalResult.h" 21 #include "llvm/ADT/Hashing.h" 22 #include "llvm/ADT/SetVector.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/ErrorHandling.h" 26 27 using namespace mlir; 28 using namespace test; 29 30 //===----------------------------------------------------------------------===// 31 // AttrWithTypeBuilderAttr 32 //===----------------------------------------------------------------------===// 33 34 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) { 35 IntegerAttr element; 36 if (parser.parseAttribute(element)) 37 return Attribute(); 38 return get(parser.getContext(), element); 39 } 40 41 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const { 42 printer << " " << getAttr(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // CompoundAAttr 47 //===----------------------------------------------------------------------===// 48 49 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) { 50 int widthOfSomething; 51 Type oneType; 52 SmallVector<int, 4> arrayOfInts; 53 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 54 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 55 parser.parseLSquare()) 56 return Attribute(); 57 58 int intVal; 59 while (!*parser.parseOptionalInteger(intVal)) { 60 arrayOfInts.push_back(intVal); 61 if (parser.parseOptionalComma()) 62 break; 63 } 64 65 if (parser.parseRSquare() || parser.parseGreater()) 66 return Attribute(); 67 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 68 } 69 70 void CompoundAAttr::print(AsmPrinter &printer) const { 71 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; 72 llvm::interleaveComma(getArrayOfInts(), printer); 73 printer << "]>"; 74 } 75 76 //===----------------------------------------------------------------------===// 77 // CompoundAAttr 78 //===----------------------------------------------------------------------===// 79 80 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { 81 SmallVector<uint64_t> elements; 82 if (parser.parseLess() || parser.parseLSquare()) 83 return Attribute(); 84 uint64_t intVal; 85 while (succeeded(*parser.parseOptionalInteger(intVal))) { 86 elements.push_back(intVal); 87 if (parser.parseOptionalComma()) 88 break; 89 } 90 91 if (parser.parseRSquare() || parser.parseGreater()) 92 return Attribute(); 93 return parser.getChecked<TestI64ElementsAttr>( 94 parser.getContext(), type.cast<ShapedType>(), elements); 95 } 96 97 void TestI64ElementsAttr::print(AsmPrinter &printer) const { 98 printer << "<["; 99 llvm::interleaveComma(getElements(), printer); 100 printer << "] : " << getType() << ">"; 101 } 102 103 LogicalResult 104 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 105 ShapedType type, ArrayRef<uint64_t> elements) { 106 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 107 return emitError() 108 << "number of elements does not match the provided shape type, got: " 109 << elements.size() << ", but expected: " << type.getNumElements(); 110 } 111 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 112 return emitError() << "expected single rank 64-bit shape type, but got: " 113 << type; 114 return success(); 115 } 116 117 LogicalResult TestAttrWithFormatAttr::verify( 118 function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two, 119 IntegerAttr three, ArrayRef<int> four, 120 ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr) { 121 if (four.size() != static_cast<unsigned>(one)) 122 return emitError() << "expected 'one' to equal 'four.size()'"; 123 return success(); 124 } 125 126 //===----------------------------------------------------------------------===// 127 // Utility Functions for Generated Attributes 128 //===----------------------------------------------------------------------===// 129 130 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) { 131 SmallVector<int> ints; 132 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() { 133 ints.push_back(0); 134 return parser.parseInteger(ints.back()); 135 }) || 136 parser.parseRSquare()) 137 return failure(); 138 return ints; 139 } 140 141 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) { 142 printer << '['; 143 llvm::interleaveComma(ints, printer); 144 printer << ']'; 145 } 146 147 //===----------------------------------------------------------------------===// 148 // TestSubElementsAccessAttr 149 //===----------------------------------------------------------------------===// 150 151 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser, 152 ::mlir::Type type) { 153 Attribute first, second, third; 154 if (parser.parseLess() || parser.parseAttribute(first) || 155 parser.parseComma() || parser.parseAttribute(second) || 156 parser.parseComma() || parser.parseAttribute(third) || 157 parser.parseGreater()) { 158 return {}; 159 } 160 return get(parser.getContext(), first, second, third); 161 } 162 163 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const { 164 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird() 165 << ">"; 166 } 167 168 void TestSubElementsAccessAttr::walkImmediateSubElements( 169 llvm::function_ref<void(mlir::Attribute)> walkAttrsFn, 170 llvm::function_ref<void(mlir::Type)> walkTypesFn) const { 171 walkAttrsFn(getFirst()); 172 walkAttrsFn(getSecond()); 173 walkAttrsFn(getThird()); 174 } 175 176 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute( 177 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 178 Attribute first = getFirst(); 179 Attribute second = getSecond(); 180 Attribute third = getThird(); 181 for (auto &it : replacements) { 182 switch (it.first) { 183 case 0: 184 first = it.second; 185 break; 186 case 1: 187 second = it.second; 188 break; 189 case 2: 190 third = it.second; 191 break; 192 } 193 } 194 return get(getContext(), first, second, third); 195 } 196 197 //===----------------------------------------------------------------------===// 198 // TestExtern1DI64ElementsAttr 199 //===----------------------------------------------------------------------===// 200 201 ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const { 202 return getHandle().getData()->getData(); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // Tablegen Generated Definitions 207 //===----------------------------------------------------------------------===// 208 209 #include "TestAttrInterfaces.cpp.inc" 210 211 #define GET_ATTRDEF_CLASSES 212 #include "TestAttrDefs.cpp.inc" 213 214 //===----------------------------------------------------------------------===// 215 // Dynamic Attributes 216 //===----------------------------------------------------------------------===// 217 218 /// Define a singleton dynamic attribute. 219 static std::unique_ptr<DynamicAttrDefinition> 220 getDynamicSingletonAttr(TestDialect *testDialect) { 221 return DynamicAttrDefinition::get( 222 "dynamic_singleton", testDialect, 223 [](function_ref<InFlightDiagnostic()> emitError, 224 ArrayRef<Attribute> args) { 225 if (!args.empty()) { 226 emitError() << "expected 0 attribute arguments, but had " 227 << args.size(); 228 return failure(); 229 } 230 return success(); 231 }); 232 } 233 234 /// Define a dynamic attribute representing a pair or attributes. 235 static std::unique_ptr<DynamicAttrDefinition> 236 getDynamicPairAttr(TestDialect *testDialect) { 237 return DynamicAttrDefinition::get( 238 "dynamic_pair", testDialect, 239 [](function_ref<InFlightDiagnostic()> emitError, 240 ArrayRef<Attribute> args) { 241 if (args.size() != 2) { 242 emitError() << "expected 2 attribute arguments, but had " 243 << args.size(); 244 return failure(); 245 } 246 return success(); 247 }); 248 } 249 250 static std::unique_ptr<DynamicAttrDefinition> 251 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) { 252 auto verifier = [](function_ref<InFlightDiagnostic()> emitError, 253 ArrayRef<Attribute> args) { 254 if (args.size() != 2) { 255 emitError() << "expected 2 attribute arguments, but had " << args.size(); 256 return failure(); 257 } 258 return success(); 259 }; 260 261 auto parser = [](AsmParser &parser, 262 llvm::SmallVectorImpl<Attribute> &parsedParams) { 263 Attribute leftAttr, rightAttr; 264 if (parser.parseLess() || parser.parseAttribute(leftAttr) || 265 parser.parseColon() || parser.parseAttribute(rightAttr) || 266 parser.parseGreater()) 267 return failure(); 268 parsedParams.push_back(leftAttr); 269 parsedParams.push_back(rightAttr); 270 return success(); 271 }; 272 273 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { 274 printer << "<" << params[0] << ":" << params[1] << ">"; 275 }; 276 277 return DynamicAttrDefinition::get("dynamic_custom_assembly_format", 278 testDialect, std::move(verifier), 279 std::move(parser), std::move(printer)); 280 } 281 282 //===----------------------------------------------------------------------===// 283 // TestDialect 284 //===----------------------------------------------------------------------===// 285 286 void TestDialect::registerAttributes() { 287 addAttributes< 288 #define GET_ATTRDEF_LIST 289 #include "TestAttrDefs.cpp.inc" 290 >(); 291 registerDynamicAttr(getDynamicSingletonAttr(this)); 292 registerDynamicAttr(getDynamicPairAttr(this)); 293 registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this)); 294 } 295