1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains attributes defined by the TestDialect for testing various 10 // features of MLIR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestAttributes.h" 15 #include "TestDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/Types.h" 19 #include "llvm/ADT/Hashing.h" 20 #include "llvm/ADT/SetVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace test; 25 26 //===----------------------------------------------------------------------===// 27 // AttrWithSelfTypeParamAttr 28 //===----------------------------------------------------------------------===// 29 30 Attribute AttrWithSelfTypeParamAttr::parse(MLIRContext *context, 31 DialectAsmParser &parser, 32 Type type) { 33 Type selfType; 34 if (parser.parseType(selfType)) 35 return Attribute(); 36 return get(context, selfType); 37 } 38 39 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const { 40 printer << "attr_with_self_type_param " << getType(); 41 } 42 43 //===----------------------------------------------------------------------===// 44 // AttrWithTypeBuilderAttr 45 //===----------------------------------------------------------------------===// 46 47 Attribute AttrWithTypeBuilderAttr::parse(MLIRContext *context, 48 DialectAsmParser &parser, Type type) { 49 IntegerAttr element; 50 if (parser.parseAttribute(element)) 51 return Attribute(); 52 return get(context, element); 53 } 54 55 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const { 56 printer << "attr_with_type_builder " << getAttr(); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // CompoundAAttr 61 //===----------------------------------------------------------------------===// 62 63 Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser, 64 Type type) { 65 int widthOfSomething; 66 Type oneType; 67 SmallVector<int, 4> arrayOfInts; 68 if (parser.parseLess() || parser.parseInteger(widthOfSomething) || 69 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || 70 parser.parseLSquare()) 71 return Attribute(); 72 73 int intVal; 74 while (!*parser.parseOptionalInteger(intVal)) { 75 arrayOfInts.push_back(intVal); 76 if (parser.parseOptionalComma()) 77 break; 78 } 79 80 if (parser.parseRSquare() || parser.parseGreater()) 81 return Attribute(); 82 return get(context, widthOfSomething, oneType, arrayOfInts); 83 } 84 85 void CompoundAAttr::print(DialectAsmPrinter &printer) const { 86 printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType() 87 << ", ["; 88 llvm::interleaveComma(getArrayOfInts(), printer); 89 printer << "]>"; 90 } 91 92 //===----------------------------------------------------------------------===// 93 // CompoundAAttr 94 //===----------------------------------------------------------------------===// 95 96 Attribute TestI64ElementsAttr::parse(MLIRContext *context, 97 DialectAsmParser &parser, Type type) { 98 SmallVector<uint64_t> elements; 99 if (parser.parseLess() || parser.parseLSquare()) 100 return Attribute(); 101 uint64_t intVal; 102 while (succeeded(*parser.parseOptionalInteger(intVal))) { 103 elements.push_back(intVal); 104 if (parser.parseOptionalComma()) 105 break; 106 } 107 108 if (parser.parseRSquare() || parser.parseGreater()) 109 return Attribute(); 110 return parser.getChecked<TestI64ElementsAttr>( 111 context, type.cast<ShapedType>(), elements); 112 } 113 114 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const { 115 printer << "i64_elements<["; 116 llvm::interleaveComma(getElements(), printer); 117 printer << "] : " << getType() << ">"; 118 } 119 120 LogicalResult 121 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 122 ShapedType type, ArrayRef<uint64_t> elements) { 123 if (type.getNumElements() != static_cast<int64_t>(elements.size())) { 124 return emitError() 125 << "number of elements does not match the provided shape type, got: " 126 << elements.size() << ", but expected: " << type.getNumElements(); 127 } 128 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) 129 return emitError() << "expected single rank 64-bit shape type, but got: " 130 << type; 131 return success(); 132 } 133 134 //===----------------------------------------------------------------------===// 135 // Tablegen Generated Definitions 136 //===----------------------------------------------------------------------===// 137 138 #include "TestAttrInterfaces.cpp.inc" 139 140 #define GET_ATTRDEF_CLASSES 141 #include "TestAttrDefs.cpp.inc" 142 143 //===----------------------------------------------------------------------===// 144 // TestDialect 145 //===----------------------------------------------------------------------===// 146 147 void TestDialect::registerAttributes() { 148 addAttributes< 149 #define GET_ATTRDEF_LIST 150 #include "TestAttrDefs.cpp.inc" 151 >(); 152 } 153 154 Attribute TestDialect::parseAttribute(DialectAsmParser &parser, 155 Type type) const { 156 StringRef attrTag; 157 if (failed(parser.parseKeyword(&attrTag))) 158 return Attribute(); 159 { 160 Attribute attr; 161 auto parseResult = 162 generatedAttributeParser(getContext(), parser, attrTag, type, attr); 163 if (parseResult.hasValue()) 164 return attr; 165 } 166 parser.emitError(parser.getNameLoc(), "unknown test attribute"); 167 return Attribute(); 168 } 169 170 void TestDialect::printAttribute(Attribute attr, 171 DialectAsmPrinter &printer) const { 172 if (succeeded(generatedAttributePrinter(attr, printer))) 173 return; 174 } 175