1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/ExtensibleDialect.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Support/LogicalResult.h" 21 #include "llvm/ADT/Hashing.h" 22 #include "llvm/ADT/SetVector.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/ErrorHandling.h" 26 27 using namespace mlir; 28 using namespace test; 29 30 //===----------------------------------------------------------------------===// 31 // AttrWithTypeBuilderAttr 32 //===----------------------------------------------------------------------===// 33 34 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) { 35 IntegerAttr element; 36 if (parser.parseAttribute(element)) 37 return Attribute(); 38 return get(parser.getContext(), element); 39 } 40 41 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const { 42 printer << " " << getAttr(); 43 } 44 45 //===----------------------------------------------------------------------===// 46 // CompoundAAttr 47 //===----------------------------------------------------------------------===// 48 49 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) { 50 int widthOfSomething; 51 Type oneType; 52 SmallVector<int, 4> arrayOfInts; 53 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 54 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 55 parser.parseLSquare()) 56 return Attribute(); 57 58 int intVal; 59 while (!*parser.parseOptionalInteger(intVal)) { 60 arrayOfInts.push_back(intVal); 61 if (parser.parseOptionalComma()) 62 break; 63 } 64 65 if (parser.parseRSquare() || parser.parseGreater()) 66 return Attribute(); 67 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 68 } 69 70 void CompoundAAttr::print(AsmPrinter &printer) const { 71 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; 72 llvm::interleaveComma(getArrayOfInts(), printer); 73 printer << "]>"; 74 } 75 76 //===----------------------------------------------------------------------===// 77 // CompoundAAttr 78 //===----------------------------------------------------------------------===// 79 80 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { 81 SmallVector<uint64_t> elements; 82 if (parser.parseLess() || parser.parseLSquare()) 83 return Attribute(); 84 uint64_t intVal; 85 while (succeeded(*parser.parseOptionalInteger(intVal))) { 86 elements.push_back(intVal); 87 if (parser.parseOptionalComma()) 88 break; 89 } 90 91 if (parser.parseRSquare() || parser.parseGreater()) 92 return Attribute(); 93 return parser.getChecked<TestI64ElementsAttr>( 94 parser.getContext(), type.cast<ShapedType>(), elements); 95 } 96 97 void TestI64ElementsAttr::print(AsmPrinter &printer) const { 98 printer << "<["; 99 llvm::interleaveComma(getElements(), printer); 100 printer << "] : " << getType() << ">"; 101 } 102 103 LogicalResult 104 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 105 ShapedType type, ArrayRef<uint64_t> elements) { 106 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 107 return emitError() 108 << "number of elements does not match the provided shape type, got: " 109 << elements.size() << ", but expected: " << type.getNumElements(); 110 } 111 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 112 return emitError() << "expected single rank 64-bit shape type, but got: " 113 << type; 114 return success(); 115 } 116 117 LogicalResult TestAttrWithFormatAttr::verify( 118 function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two, 119 IntegerAttr three, ArrayRef<int> four, 120 ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr) { 121 if (four.size() != static_cast<unsigned>(one)) 122 return emitError() << "expected 'one' to equal 'four.size()'"; 123 return success(); 124 } 125 126 //===----------------------------------------------------------------------===// 127 // Utility Functions for Generated Attributes 128 //===----------------------------------------------------------------------===// 129 130 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) { 131 SmallVector<int> ints; 132 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() { 133 ints.push_back(0); 134 return parser.parseInteger(ints.back()); 135 }) || 136 parser.parseRSquare()) 137 return failure(); 138 return ints; 139 } 140 141 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) { 142 printer << '['; 143 llvm::interleaveComma(ints, printer); 144 printer << ']'; 145 } 146 147 //===----------------------------------------------------------------------===// 148 // TestSubElementsAccessAttr 149 //===----------------------------------------------------------------------===// 150 151 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser, 152 ::mlir::Type type) { 153 Attribute first, second, third; 154 if (parser.parseLess() || parser.parseAttribute(first) || 155 parser.parseComma() || parser.parseAttribute(second) || 156 parser.parseComma() || parser.parseAttribute(third) || 157 parser.parseGreater()) { 158 return {}; 159 } 160 return get(parser.getContext(), first, second, third); 161 } 162 163 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const { 164 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird() 165 << ">"; 166 } 167 168 void TestSubElementsAccessAttr::walkImmediateSubElements( 169 llvm::function_ref<void(mlir::Attribute)> walkAttrsFn, 170 llvm::function_ref<void(mlir::Type)> walkTypesFn) const { 171 walkAttrsFn(getFirst()); 172 walkAttrsFn(getSecond()); 173 walkAttrsFn(getThird()); 174 } 175 176 Attribute TestSubElementsAccessAttr::replaceImmediateSubElements( 177 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const { 178 assert(replAttrs.size() == 3 && "invalid number of replacement attributes"); 179 return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]); 180 } 181 182 //===----------------------------------------------------------------------===// 183 // TestExtern1DI64ElementsAttr 184 //===----------------------------------------------------------------------===// 185 186 ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const { 187 return getHandle().getData()->getData(); 188 } 189 190 //===----------------------------------------------------------------------===// 191 // Tablegen Generated Definitions 192 //===----------------------------------------------------------------------===// 193 194 #include "TestAttrInterfaces.cpp.inc" 195 196 #define GET_ATTRDEF_CLASSES 197 #include "TestAttrDefs.cpp.inc" 198 199 //===----------------------------------------------------------------------===// 200 // Dynamic Attributes 201 //===----------------------------------------------------------------------===// 202 203 /// Define a singleton dynamic attribute. 204 static std::unique_ptr<DynamicAttrDefinition> 205 getDynamicSingletonAttr(TestDialect *testDialect) { 206 return DynamicAttrDefinition::get( 207 "dynamic_singleton", testDialect, 208 [](function_ref<InFlightDiagnostic()> emitError, 209 ArrayRef<Attribute> args) { 210 if (!args.empty()) { 211 emitError() << "expected 0 attribute arguments, but had " 212 << args.size(); 213 return failure(); 214 } 215 return success(); 216 }); 217 } 218 219 /// Define a dynamic attribute representing a pair or attributes. 220 static std::unique_ptr<DynamicAttrDefinition> 221 getDynamicPairAttr(TestDialect *testDialect) { 222 return DynamicAttrDefinition::get( 223 "dynamic_pair", testDialect, 224 [](function_ref<InFlightDiagnostic()> emitError, 225 ArrayRef<Attribute> args) { 226 if (args.size() != 2) { 227 emitError() << "expected 2 attribute arguments, but had " 228 << args.size(); 229 return failure(); 230 } 231 return success(); 232 }); 233 } 234 235 static std::unique_ptr<DynamicAttrDefinition> 236 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) { 237 auto verifier = [](function_ref<InFlightDiagnostic()> emitError, 238 ArrayRef<Attribute> args) { 239 if (args.size() != 2) { 240 emitError() << "expected 2 attribute arguments, but had " << args.size(); 241 return failure(); 242 } 243 return success(); 244 }; 245 246 auto parser = [](AsmParser &parser, 247 llvm::SmallVectorImpl<Attribute> &parsedParams) { 248 Attribute leftAttr, rightAttr; 249 if (parser.parseLess() || parser.parseAttribute(leftAttr) || 250 parser.parseColon() || parser.parseAttribute(rightAttr) || 251 parser.parseGreater()) 252 return failure(); 253 parsedParams.push_back(leftAttr); 254 parsedParams.push_back(rightAttr); 255 return success(); 256 }; 257 258 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { 259 printer << "<" << params[0] << ":" << params[1] << ">"; 260 }; 261 262 return DynamicAttrDefinition::get("dynamic_custom_assembly_format", 263 testDialect, std::move(verifier), 264 std::move(parser), std::move(printer)); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // TestDialect 269 //===----------------------------------------------------------------------===// 270 271 void TestDialect::registerAttributes() { 272 addAttributes< 273 #define GET_ATTRDEF_LIST 274 #include "TestAttrDefs.cpp.inc" 275 >(); 276 registerDynamicAttr(getDynamicSingletonAttr(this)); 277 registerDynamicAttr(getDynamicPairAttr(this)); 278 registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this)); 279 } 280