1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/Types.h" 19 #include "llvm/ADT/Hashing.h" 20 #include "llvm/ADT/SetVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace test; 25 26 //===----------------------------------------------------------------------===// 27 // AttrWithSelfTypeParamAttr 28 //===----------------------------------------------------------------------===// 29 30 Attribute AttrWithSelfTypeParamAttr::parse(DialectAsmParser &parser, 31 Type type) { 32 Type selfType; 33 if (parser.parseType(selfType)) 34 return Attribute(); 35 return get(parser.getContext(), selfType); 36 } 37 38 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const { 39 printer << "attr_with_self_type_param " << getType(); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // AttrWithTypeBuilderAttr 44 //===----------------------------------------------------------------------===// 45 46 Attribute AttrWithTypeBuilderAttr::parse(DialectAsmParser &parser, Type type) { 47 IntegerAttr element; 48 if (parser.parseAttribute(element)) 49 return Attribute(); 50 return get(parser.getContext(), element); 51 } 52 53 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const { 54 printer << "attr_with_type_builder " << getAttr(); 55 } 56 57 //===----------------------------------------------------------------------===// 58 // CompoundAAttr 59 //===----------------------------------------------------------------------===// 60 61 Attribute CompoundAAttr::parse(DialectAsmParser &parser, Type type) { 62 int widthOfSomething; 63 Type oneType; 64 SmallVector<int, 4> arrayOfInts; 65 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 66 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 67 parser.parseLSquare()) 68 return Attribute(); 69 70 int intVal; 71 while (!*parser.parseOptionalInteger(intVal)) { 72 arrayOfInts.push_back(intVal); 73 if (parser.parseOptionalComma()) 74 break; 75 } 76 77 if (parser.parseRSquare() || parser.parseGreater()) 78 return Attribute(); 79 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 80 } 81 82 void CompoundAAttr::print(DialectAsmPrinter &printer) const { 83 printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType() 84 << ", ["; 85 llvm::interleaveComma(getArrayOfInts(), printer); 86 printer << "]>"; 87 } 88 89 //===----------------------------------------------------------------------===// 90 // CompoundAAttr 91 //===----------------------------------------------------------------------===// 92 93 Attribute TestI64ElementsAttr::parse(DialectAsmParser &parser, Type type) { 94 SmallVector<uint64_t> elements; 95 if (parser.parseLess() || parser.parseLSquare()) 96 return Attribute(); 97 uint64_t intVal; 98 while (succeeded(*parser.parseOptionalInteger(intVal))) { 99 elements.push_back(intVal); 100 if (parser.parseOptionalComma()) 101 break; 102 } 103 104 if (parser.parseRSquare() || parser.parseGreater()) 105 return Attribute(); 106 return parser.getChecked<TestI64ElementsAttr>( 107 parser.getContext(), type.cast<ShapedType>(), elements); 108 } 109 110 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const { 111 printer << "i64_elements<["; 112 llvm::interleaveComma(getElements(), printer); 113 printer << "] : " << getType() << ">"; 114 } 115 116 LogicalResult 117 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 118 ShapedType type, ArrayRef<uint64_t> elements) { 119 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 120 return emitError() 121 << "number of elements does not match the provided shape type, got: " 122 << elements.size() << ", but expected: " << type.getNumElements(); 123 } 124 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 125 return emitError() << "expected single rank 64-bit shape type, but got: " 126 << type; 127 return success(); 128 } 129 130 //===----------------------------------------------------------------------===// 131 // TestSubElementsAccessAttr 132 //===----------------------------------------------------------------------===// 133 134 Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser, 135 ::mlir::Type type) { 136 Attribute first, second, third; 137 if (parser.parseLess() || parser.parseAttribute(first) || 138 parser.parseComma() || parser.parseAttribute(second) || 139 parser.parseComma() || parser.parseAttribute(third) || 140 parser.parseGreater()) { 141 return {}; 142 } 143 return get(parser.getContext(), first, second, third); 144 } 145 146 void TestSubElementsAccessAttr::print( 147 ::mlir::DialectAsmPrinter &printer) const { 148 printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", " 149 << getThird() << ">"; 150 } 151 152 void TestSubElementsAccessAttr::walkImmediateSubElements( 153 llvm::function_ref<void(mlir::Attribute)> walkAttrsFn, 154 llvm::function_ref<void(mlir::Type)> walkTypesFn) const { 155 walkAttrsFn(getFirst()); 156 walkAttrsFn(getSecond()); 157 walkAttrsFn(getThird()); 158 } 159 160 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute( 161 ArrayRef<std::pair<size_t, Attribute>> replacements) const { 162 Attribute first = getFirst(); 163 Attribute second = getSecond(); 164 Attribute third = getThird(); 165 for (auto &it : replacements) { 166 switch (it.first) { 167 case 0: 168 first = it.second; 169 break; 170 case 1: 171 second = it.second; 172 break; 173 case 2: 174 third = it.second; 175 break; 176 } 177 } 178 return get(getContext(), first, second, third); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // Tablegen Generated Definitions 183 //===----------------------------------------------------------------------===// 184 185 #include "TestAttrInterfaces.cpp.inc" 186 187 #define GET_ATTRDEF_CLASSES 188 #include "TestAttrDefs.cpp.inc" 189 190 //===----------------------------------------------------------------------===// 191 // TestDialect 192 //===----------------------------------------------------------------------===// 193 194 void TestDialect::registerAttributes() { 195 addAttributes< 196 #define GET_ATTRDEF_LIST 197 #include "TestAttrDefs.cpp.inc" 198 >(); 199 } 200 201 Attribute TestDialect::parseAttribute(DialectAsmParser &parser, 202 Type type) const { 203 StringRef attrTag; 204 if (failed(parser.parseKeyword(&attrTag))) 205 return Attribute(); 206 { 207 Attribute attr; 208 auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); 209 if (parseResult.hasValue()) 210 return attr; 211 } 212 parser.emitError(parser.getNameLoc(), "unknown test attribute"); 213 return Attribute(); 214 } 215 216 void TestDialect::printAttribute(Attribute attr, 217 DialectAsmPrinter &printer) const { 218 if (succeeded(generatedAttributePrinter(attr, printer))) 219 return; 220 } 221