1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/Types.h" 19 #include "llvm/ADT/Hashing.h" 20 #include "llvm/ADT/SetVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace test; 25 26 //===----------------------------------------------------------------------===// 27 // AttrWithSelfTypeParamAttr 28 //===----------------------------------------------------------------------===// 29 30 Attribute AttrWithSelfTypeParamAttr::parse(DialectAsmParser &parser, 31 Type type) { 32 Type selfType; 33 if (parser.parseType(selfType)) 34 return Attribute(); 35 return get(parser.getContext(), selfType); 36 } 37 38 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const { 39 printer << "attr_with_self_type_param " << getType(); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // AttrWithTypeBuilderAttr 44 //===----------------------------------------------------------------------===// 45 46 Attribute AttrWithTypeBuilderAttr::parse(DialectAsmParser &parser, Type type) { 47 IntegerAttr element; 48 if (parser.parseAttribute(element)) 49 return Attribute(); 50 return get(parser.getContext(), element); 51 } 52 53 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const { 54 printer << "attr_with_type_builder " << getAttr(); 55 } 56 57 //===----------------------------------------------------------------------===// 58 // CompoundAAttr 59 //===----------------------------------------------------------------------===// 60 61 Attribute CompoundAAttr::parse(DialectAsmParser &parser, Type type) { 62 int widthOfSomething; 63 Type oneType; 64 SmallVector<int, 4> arrayOfInts; 65 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 66 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 67 parser.parseLSquare()) 68 return Attribute(); 69 70 int intVal; 71 while (!*parser.parseOptionalInteger(intVal)) { 72 arrayOfInts.push_back(intVal); 73 if (parser.parseOptionalComma()) 74 break; 75 } 76 77 if (parser.parseRSquare() || parser.parseGreater()) 78 return Attribute(); 79 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); 80 } 81 82 void CompoundAAttr::print(DialectAsmPrinter &printer) const { 83 printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType() 84 << ", ["; 85 llvm::interleaveComma(getArrayOfInts(), printer); 86 printer << "]>"; 87 } 88 89 //===----------------------------------------------------------------------===// 90 // CompoundAAttr 91 //===----------------------------------------------------------------------===// 92 93 Attribute TestI64ElementsAttr::parse(DialectAsmParser &parser, Type type) { 94 SmallVector<uint64_t> elements; 95 if (parser.parseLess() || parser.parseLSquare()) 96 return Attribute(); 97 uint64_t intVal; 98 while (succeeded(*parser.parseOptionalInteger(intVal))) { 99 elements.push_back(intVal); 100 if (parser.parseOptionalComma()) 101 break; 102 } 103 104 if (parser.parseRSquare() || parser.parseGreater()) 105 return Attribute(); 106 return parser.getChecked<TestI64ElementsAttr>( 107 parser.getContext(), type.cast<ShapedType>(), elements); 108 } 109 110 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const { 111 printer << "i64_elements<["; 112 llvm::interleaveComma(getElements(), printer); 113 printer << "] : " << getType() << ">"; 114 } 115 116 LogicalResult 117 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 118 ShapedType type, ArrayRef<uint64_t> elements) { 119 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 120 return emitError() 121 << "number of elements does not match the provided shape type, got: " 122 << elements.size() << ", but expected: " << type.getNumElements(); 123 } 124 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 125 return emitError() << "expected single rank 64-bit shape type, but got: " 126 << type; 127 return success(); 128 } 129 130 //===----------------------------------------------------------------------===// 131 // Tablegen Generated Definitions 132 //===----------------------------------------------------------------------===// 133 134 #include "TestAttrInterfaces.cpp.inc" 135 136 #define GET_ATTRDEF_CLASSES 137 #include "TestAttrDefs.cpp.inc" 138 139 //===----------------------------------------------------------------------===// 140 // TestDialect 141 //===----------------------------------------------------------------------===// 142 143 void TestDialect::registerAttributes() { 144 addAttributes< 145 #define GET_ATTRDEF_LIST 146 #include "TestAttrDefs.cpp.inc" 147 >(); 148 } 149 150 Attribute TestDialect::parseAttribute(DialectAsmParser &parser, 151 Type type) const { 152 StringRef attrTag; 153 if (failed(parser.parseKeyword(&attrTag))) 154 return Attribute(); 155 { 156 Attribute attr; 157 auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); 158 if (parseResult.hasValue()) 159 return attr; 160 } 161 parser.emitError(parser.getNameLoc(), "unknown test attribute"); 162 return Attribute(); 163 } 164 165 void TestDialect::printAttribute(Attribute attr, 166 DialectAsmPrinter &printer) const { 167 if (succeeded(generatedAttributePrinter(attr, printer))) 168 return; 169 } 170