1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/Types.h" 19 #include "mlir/Support/LogicalResult.h" 20 #include "llvm/ADT/Hashing.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/ADT/bit.h" 24 25 using namespace mlir; 26 using namespace test; 27 28 //===----------------------------------------------------------------------===// 29 // AttrWithSelfTypeParamAttr 30 //===----------------------------------------------------------------------===// 31 32 Attribute AttrWithSelfTypeParamAttr::parse(AsmParser &parser, Type type) { 33 Type selfType; 34 if (parser.parseType(selfType)) 35 return Attribute(); 36 return get(parser.getContext(), selfType); 37 } 38 39 void AttrWithSelfTypeParamAttr::print(AsmPrinter &printer) const { 40 printer << " " << getType(); 41 } 42 43 //===----------------------------------------------------------------------===// 44 // AttrWithTypeBuilderAttr 45 //===----------------------------------------------------------------------===// 46 47 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) { 48 IntegerAttr element; 49 if (parser.parseAttribute(element)) 50 return Attribute(); 51 return get(parser.getContext(), element); 52 } 53 54 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const { 55 printer << " " << getAttr(); 56 } 57 58 //===----------------------------------------------------------------------===// 59 // CompoundAAttr 60 //===----------------------------------------------------------------------===// 61 62 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) { 63 int widthOfSomething; 64 Type oneType; 65 SmallVector<int, 4> arrayOfInts; 66 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 67 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 68 parser.parseLSquare()) 69 return Attribute(); 70 71 int intVal; 72 while (!*parser.parseOptionalInteger(intVal)) { 73 arrayOfInts.push_back(intVal); 74 if (parser.parseOptionalComma()) 75 break; 76 } 77 78 if (parser.parseRSquare() || parser.parseGreater()) 79 return Attribute(); 80 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 81 } 82 83 void CompoundAAttr::print(AsmPrinter &printer) const { 84 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; 85 llvm::interleaveComma(getArrayOfInts(), printer); 86 printer << "]>"; 87 } 88 89 //===----------------------------------------------------------------------===// 90 // CompoundAAttr 91 //===----------------------------------------------------------------------===// 92 93 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { 94 SmallVector<uint64_t> elements; 95 if (parser.parseLess() || parser.parseLSquare()) 96 return Attribute(); 97 uint64_t intVal; 98 while (succeeded(*parser.parseOptionalInteger(intVal))) { 99 elements.push_back(intVal); 100 if (parser.parseOptionalComma()) 101 break; 102 } 103 104 if (parser.parseRSquare() || parser.parseGreater()) 105 return Attribute(); 106 return parser.getChecked<TestI64ElementsAttr>( 107 parser.getContext(), type.cast<ShapedType>(), elements); 108 } 109 110 void TestI64ElementsAttr::print(AsmPrinter &printer) const { 111 printer << "<["; 112 llvm::interleaveComma(getElements(), printer); 113 printer << "] : " << getType() << ">"; 114 } 115 116 LogicalResult 117 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 118 ShapedType type, ArrayRef<uint64_t> elements) { 119 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 120 return emitError() 121 << "number of elements does not match the provided shape type, got: " 122 << elements.size() << ", but expected: " << type.getNumElements(); 123 } 124 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 125 return emitError() << "expected single rank 64-bit shape type, but got: " 126 << type; 127 return success(); 128 } 129 130 LogicalResult 131 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 132 int64_t one, std::string two, IntegerAttr three, 133 ArrayRef<int> four) { 134 if (four.size() != static_cast<unsigned>(one)) 135 return emitError() << "expected 'one' to equal 'four.size()'"; 136 return success(); 137 } 138 139 //===----------------------------------------------------------------------===// 140 // Utility Functions for Generated Attributes 141 //===----------------------------------------------------------------------===// 142 143 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) { 144 SmallVector<int> ints; 145 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() { 146 ints.push_back(0); 147 return parser.parseInteger(ints.back()); 148 }) || 149 parser.parseRSquare()) 150 return failure(); 151 return ints; 152 } 153 154 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) { 155 printer << '['; 156 llvm::interleaveComma(ints, printer); 157 printer << ']'; 158 } 159 160 //===----------------------------------------------------------------------===// 161 // TestSubElementsAccessAttr 162 //===----------------------------------------------------------------------===// 163 164 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser, 165 ::mlir::Type type) { 166 Attribute first, second, third; 167 if (parser.parseLess() || parser.parseAttribute(first) || 168 parser.parseComma() || parser.parseAttribute(second) || 169 parser.parseComma() || parser.parseAttribute(third) || 170 parser.parseGreater()) { 171 return {}; 172 } 173 return get(parser.getContext(), first, second, third); 174 } 175 176 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const { 177 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird() 178 << ">"; 179 } 180 181 void TestSubElementsAccessAttr::walkImmediateSubElements( 182 llvm::function_ref<void(mlir::Attribute)> walkAttrsFn, 183 llvm::function_ref<void(mlir::Type)> walkTypesFn) const { 184 walkAttrsFn(getFirst()); 185 walkAttrsFn(getSecond()); 186 walkAttrsFn(getThird()); 187 } 188 189 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute( 190 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 191 Attribute first = getFirst(); 192 Attribute second = getSecond(); 193 Attribute third = getThird(); 194 for (auto &it : replacements) { 195 switch (it.first) { 196 case 0: 197 first = it.second; 198 break; 199 case 1: 200 second = it.second; 201 break; 202 case 2: 203 third = it.second; 204 break; 205 } 206 } 207 return get(getContext(), first, second, third); 208 } 209 210 //===----------------------------------------------------------------------===// 211 // Tablegen Generated Definitions 212 //===----------------------------------------------------------------------===// 213 214 #include "TestAttrInterfaces.cpp.inc" 215 216 #define GET_ATTRDEF_CLASSES 217 #include "TestAttrDefs.cpp.inc" 218 219 //===----------------------------------------------------------------------===// 220 // TestDialect 221 //===----------------------------------------------------------------------===// 222 223 void TestDialect::registerAttributes() { 224 addAttributes< 225 #define GET_ATTRDEF_LIST 226 #include "TestAttrDefs.cpp.inc" 227 >(); 228 } 229