1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/Types.h" 19 #include "mlir/Support/LogicalResult.h" 20 #include "llvm/ADT/Hashing.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/ADT/bit.h" 24 25 using namespace mlir; 26 using namespace test; 27 28 //===----------------------------------------------------------------------===// 29 // AttrWithSelfTypeParamAttr 30 //===----------------------------------------------------------------------===// 31 32 Attribute AttrWithSelfTypeParamAttr::parse(DialectAsmParser &parser, 33 Type type) { 34 Type selfType; 35 if (parser.parseType(selfType)) 36 return Attribute(); 37 return get(parser.getContext(), selfType); 38 } 39 40 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const { 41 printer << " " << getType(); 42 } 43 44 //===----------------------------------------------------------------------===// 45 // AttrWithTypeBuilderAttr 46 //===----------------------------------------------------------------------===// 47 48 Attribute AttrWithTypeBuilderAttr::parse(DialectAsmParser &parser, Type type) { 49 IntegerAttr element; 50 if (parser.parseAttribute(element)) 51 return Attribute(); 52 return get(parser.getContext(), element); 53 } 54 55 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const { 56 printer << " " << getAttr(); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // CompoundAAttr 61 //===----------------------------------------------------------------------===// 62 63 Attribute CompoundAAttr::parse(DialectAsmParser &parser, Type type) { 64 int widthOfSomething; 65 Type oneType; 66 SmallVector<int, 4> arrayOfInts; 67 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 68 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 69 parser.parseLSquare()) 70 return Attribute(); 71 72 int intVal; 73 while (!*parser.parseOptionalInteger(intVal)) { 74 arrayOfInts.push_back(intVal); 75 if (parser.parseOptionalComma()) 76 break; 77 } 78 79 if (parser.parseRSquare() || parser.parseGreater()) 80 return Attribute(); 81 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 82 } 83 84 void CompoundAAttr::print(DialectAsmPrinter &printer) const { 85 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; 86 llvm::interleaveComma(getArrayOfInts(), printer); 87 printer << "]>"; 88 } 89 90 //===----------------------------------------------------------------------===// 91 // CompoundAAttr 92 //===----------------------------------------------------------------------===// 93 94 Attribute TestI64ElementsAttr::parse(DialectAsmParser &parser, Type type) { 95 SmallVector<uint64_t> elements; 96 if (parser.parseLess() || parser.parseLSquare()) 97 return Attribute(); 98 uint64_t intVal; 99 while (succeeded(*parser.parseOptionalInteger(intVal))) { 100 elements.push_back(intVal); 101 if (parser.parseOptionalComma()) 102 break; 103 } 104 105 if (parser.parseRSquare() || parser.parseGreater()) 106 return Attribute(); 107 return parser.getChecked<TestI64ElementsAttr>( 108 parser.getContext(), type.cast<ShapedType>(), elements); 109 } 110 111 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const { 112 printer << "<["; 113 llvm::interleaveComma(getElements(), printer); 114 printer << "] : " << getType() << ">"; 115 } 116 117 LogicalResult 118 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 119 ShapedType type, ArrayRef<uint64_t> elements) { 120 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 121 return emitError() 122 << "number of elements does not match the provided shape type, got: " 123 << elements.size() << ", but expected: " << type.getNumElements(); 124 } 125 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 126 return emitError() << "expected single rank 64-bit shape type, but got: " 127 << type; 128 return success(); 129 } 130 131 LogicalResult 132 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError, 133 int64_t one, std::string two, IntegerAttr three, 134 ArrayRef<int> four) { 135 if (four.size() != static_cast<unsigned>(one)) 136 return emitError() << "expected 'one' to equal 'four.size()'"; 137 return success(); 138 } 139 140 //===----------------------------------------------------------------------===// 141 // Utility Functions for Generated Attributes 142 //===----------------------------------------------------------------------===// 143 144 static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) { 145 SmallVector<int> ints; 146 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() { 147 ints.push_back(0); 148 return parser.parseInteger(ints.back()); 149 }) || 150 parser.parseRSquare()) 151 return failure(); 152 return ints; 153 } 154 155 static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) { 156 printer << '['; 157 llvm::interleaveComma(ints, printer); 158 printer << ']'; 159 } 160 161 //===----------------------------------------------------------------------===// 162 // TestSubElementsAccessAttr 163 //===----------------------------------------------------------------------===// 164 165 Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser, 166 ::mlir::Type type) { 167 Attribute first, second, third; 168 if (parser.parseLess() || parser.parseAttribute(first) || 169 parser.parseComma() || parser.parseAttribute(second) || 170 parser.parseComma() || parser.parseAttribute(third) || 171 parser.parseGreater()) { 172 return {}; 173 } 174 return get(parser.getContext(), first, second, third); 175 } 176 177 void TestSubElementsAccessAttr::print( 178 ::mlir::DialectAsmPrinter &printer) const { 179 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird() 180 << ">"; 181 } 182 183 void TestSubElementsAccessAttr::walkImmediateSubElements( 184 llvm::function_ref<void(mlir::Attribute)> walkAttrsFn, 185 llvm::function_ref<void(mlir::Type)> walkTypesFn) const { 186 walkAttrsFn(getFirst()); 187 walkAttrsFn(getSecond()); 188 walkAttrsFn(getThird()); 189 } 190 191 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute( 192 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 193 Attribute first = getFirst(); 194 Attribute second = getSecond(); 195 Attribute third = getThird(); 196 for (auto &it : replacements) { 197 switch (it.first) { 198 case 0: 199 first = it.second; 200 break; 201 case 1: 202 second = it.second; 203 break; 204 case 2: 205 third = it.second; 206 break; 207 } 208 } 209 return get(getContext(), first, second, third); 210 } 211 212 //===----------------------------------------------------------------------===// 213 // Tablegen Generated Definitions 214 //===----------------------------------------------------------------------===// 215 216 #include "TestAttrInterfaces.cpp.inc" 217 218 #define GET_ATTRDEF_CLASSES 219 #include "TestAttrDefs.cpp.inc" 220 221 //===----------------------------------------------------------------------===// 222 // TestDialect 223 //===----------------------------------------------------------------------===// 224 225 void TestDialect::registerAttributes() { 226 addAttributes< 227 #define GET_ATTRDEF_LIST 228 #include "TestAttrDefs.cpp.inc" 229 >(); 230 } 231