1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/ExtensibleDialect.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Support/LogicalResult.h" 21 #include "llvm/ADT/Hashing.h" 22 #include "llvm/ADT/SetVector.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/ErrorHandling.h" 26 27 using namespace mlir; 28 using namespace test; 29 30 //===----------------------------------------------------------------------===// 31 // AttrWithSelfTypeParamAttr 32 //===----------------------------------------------------------------------===// 33 34 Attribute AttrWithSelfTypeParamAttr::parse(AsmParser &parser, Type type) { 35 Type selfType; 36 if (parser.parseType(selfType)) 37 return Attribute(); 38 return get(parser.getContext(), selfType); 39 } 40 41 void AttrWithSelfTypeParamAttr::print(AsmPrinter &printer) const { 42 printer << " " << getType(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // AttrWithTypeBuilderAttr 47 //===----------------------------------------------------------------------===// 48 49 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) { 50 IntegerAttr element; 51 if (parser.parseAttribute(element)) 52 return Attribute(); 53 return get(parser.getContext(), element); 54 } 55 56 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const { 57 printer << " " << getAttr(); 58 } 59 60 //===----------------------------------------------------------------------===// 61 // CompoundAAttr 62 //===----------------------------------------------------------------------===// 63 64 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) { 65 int widthOfSomething; 66 Type oneType; 67 SmallVector<int, 4> arrayOfInts; 68 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 69 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 70 parser.parseLSquare()) 71 return Attribute(); 72 73 int intVal; 74 while (!*parser.parseOptionalInteger(intVal)) { 75 arrayOfInts.push_back(intVal); 76 if (parser.parseOptionalComma()) 77 break; 78 } 79 80 if (parser.parseRSquare() || parser.parseGreater()) 81 return Attribute(); 82 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 83 } 84 85 void CompoundAAttr::print(AsmPrinter &printer) const { 86 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; 87 llvm::interleaveComma(getArrayOfInts(), printer); 88 printer << "]>"; 89 } 90 91 //===----------------------------------------------------------------------===// 92 // CompoundAAttr 93 //===----------------------------------------------------------------------===// 94 95 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { 96 SmallVector<uint64_t> elements; 97 if (parser.parseLess() || parser.parseLSquare()) 98 return Attribute(); 99 uint64_t intVal; 100 while (succeeded(*parser.parseOptionalInteger(intVal))) { 101 elements.push_back(intVal); 102 if (parser.parseOptionalComma()) 103 break; 104 } 105 106 if (parser.parseRSquare() || parser.parseGreater()) 107 return Attribute(); 108 return parser.getChecked<TestI64ElementsAttr>( 109 parser.getContext(), type.cast<ShapedType>(), elements); 110 } 111 112 void TestI64ElementsAttr::print(AsmPrinter &printer) const { 113 printer << "<["; 114 llvm::interleaveComma(getElements(), printer); 115 printer << "] : " << getType() << ">"; 116 } 117 118 LogicalResult 119 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 120 ShapedType type, ArrayRef<uint64_t> elements) { 121 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 122 return emitError() 123 << "number of elements does not match the provided shape type, got: " 124 << elements.size() << ", but expected: " << type.getNumElements(); 125 } 126 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 127 return emitError() << "expected single rank 64-bit shape type, but got: " 128 << type; 129 return success(); 130 } 131 132 LogicalResult 133 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 134 int64_t one, std::string two, IntegerAttr three, 135 ArrayRef<int> four) { 136 if (four.size() != static_cast<unsigned>(one)) 137 return emitError() << "expected 'one' to equal 'four.size()'"; 138 return success(); 139 } 140 141 //===----------------------------------------------------------------------===// 142 // Utility Functions for Generated Attributes 143 //===----------------------------------------------------------------------===// 144 145 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) { 146 SmallVector<int> ints; 147 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() { 148 ints.push_back(0); 149 return parser.parseInteger(ints.back()); 150 }) || 151 parser.parseRSquare()) 152 return failure(); 153 return ints; 154 } 155 156 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) { 157 printer << '['; 158 llvm::interleaveComma(ints, printer); 159 printer << ']'; 160 } 161 162 //===----------------------------------------------------------------------===// 163 // TestSubElementsAccessAttr 164 //===----------------------------------------------------------------------===// 165 166 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser, 167 ::mlir::Type type) { 168 Attribute first, second, third; 169 if (parser.parseLess() || parser.parseAttribute(first) || 170 parser.parseComma() || parser.parseAttribute(second) || 171 parser.parseComma() || parser.parseAttribute(third) || 172 parser.parseGreater()) { 173 return {}; 174 } 175 return get(parser.getContext(), first, second, third); 176 } 177 178 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const { 179 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird() 180 << ">"; 181 } 182 183 void TestSubElementsAccessAttr::walkImmediateSubElements( 184 llvm::function_ref<void(mlir::Attribute)> walkAttrsFn, 185 llvm::function_ref<void(mlir::Type)> walkTypesFn) const { 186 walkAttrsFn(getFirst()); 187 walkAttrsFn(getSecond()); 188 walkAttrsFn(getThird()); 189 } 190 191 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute( 192 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 193 Attribute first = getFirst(); 194 Attribute second = getSecond(); 195 Attribute third = getThird(); 196 for (auto &it : replacements) { 197 switch (it.first) { 198 case 0: 199 first = it.second; 200 break; 201 case 1: 202 second = it.second; 203 break; 204 case 2: 205 third = it.second; 206 break; 207 } 208 } 209 return get(getContext(), first, second, third); 210 } 211 212 //===----------------------------------------------------------------------===// 213 // Tablegen Generated Definitions 214 //===----------------------------------------------------------------------===// 215 216 #include "TestAttrInterfaces.cpp.inc" 217 218 #define GET_ATTRDEF_CLASSES 219 #include "TestAttrDefs.cpp.inc" 220 221 //===----------------------------------------------------------------------===// 222 // Dynamic Attributes 223 //===----------------------------------------------------------------------===// 224 225 /// Define a singleton dynamic attribute. 226 static std::unique_ptr<DynamicAttrDefinition> 227 getDynamicSingletonAttr(TestDialect *testDialect) { 228 return DynamicAttrDefinition::get( 229 "dynamic_singleton", testDialect, 230 [](function_ref<InFlightDiagnostic()> emitError, 231 ArrayRef<Attribute> args) { 232 if (!args.empty()) { 233 emitError() << "expected 0 attribute arguments, but had " 234 << args.size(); 235 return failure(); 236 } 237 return success(); 238 }); 239 } 240 241 /// Define a dynamic attribute representing a pair or attributes. 242 static std::unique_ptr<DynamicAttrDefinition> 243 getDynamicPairAttr(TestDialect *testDialect) { 244 return DynamicAttrDefinition::get( 245 "dynamic_pair", testDialect, 246 [](function_ref<InFlightDiagnostic()> emitError, 247 ArrayRef<Attribute> args) { 248 if (args.size() != 2) { 249 emitError() << "expected 2 attribute arguments, but had " 250 << args.size(); 251 return failure(); 252 } 253 return success(); 254 }); 255 } 256 257 static std::unique_ptr<DynamicAttrDefinition> 258 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) { 259 auto verifier = [](function_ref<InFlightDiagnostic()> emitError, 260 ArrayRef<Attribute> args) { 261 if (args.size() != 2) { 262 emitError() << "expected 2 attribute arguments, but had " << args.size(); 263 return failure(); 264 } 265 return success(); 266 }; 267 268 auto parser = [](AsmParser &parser, 269 llvm::SmallVectorImpl<Attribute> &parsedParams) { 270 Attribute leftAttr, rightAttr; 271 if (parser.parseLess() || parser.parseAttribute(leftAttr) || 272 parser.parseColon() || parser.parseAttribute(rightAttr) || 273 parser.parseGreater()) 274 return failure(); 275 parsedParams.push_back(leftAttr); 276 parsedParams.push_back(rightAttr); 277 return success(); 278 }; 279 280 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { 281 printer << "<" << params[0] << ":" << params[1] << ">"; 282 }; 283 284 return DynamicAttrDefinition::get("dynamic_custom_assembly_format", 285 testDialect, std::move(verifier), 286 std::move(parser), std::move(printer)); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // TestDialect 291 //===----------------------------------------------------------------------===// 292 293 void TestDialect::registerAttributes() { 294 addAttributes< 295 #define GET_ATTRDEF_LIST 296 #include "TestAttrDefs.cpp.inc" 297 >(); 298 registerDynamicAttr(getDynamicSingletonAttr(this)); 299 registerDynamicAttr(getDynamicPairAttr(this)); 300 registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this)); 301 } 302