1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/Types.h"
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace test;
25 
26 //===----------------------------------------------------------------------===//
27 // AttrWithSelfTypeParamAttr
28 //===----------------------------------------------------------------------===//
29 
30 Attribute AttrWithSelfTypeParamAttr::parse(MLIRContext *context,
31                                            DialectAsmParser &parser,
32                                            Type type) {
33   Type selfType;
34   if (parser.parseType(selfType))
35     return Attribute();
36   return get(context, selfType);
37 }
38 
39 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const {
40   printer << "attr_with_self_type_param " << getType();
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // AttrWithTypeBuilderAttr
45 //===----------------------------------------------------------------------===//
46 
47 Attribute AttrWithTypeBuilderAttr::parse(MLIRContext *context,
48                                          DialectAsmParser &parser, Type type) {
49   IntegerAttr element;
50   if (parser.parseAttribute(element))
51     return Attribute();
52   return get(context, element);
53 }
54 
55 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const {
56   printer << "attr_with_type_builder " << getAttr();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // CompoundAAttr
61 //===----------------------------------------------------------------------===//
62 
63 Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser,
64                                Type type) {
65   int widthOfSomething;
66   Type oneType;
67   SmallVector<int, 4> arrayOfInts;
68   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
69       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
70       parser.parseLSquare())
71     return Attribute();
72 
73   int intVal;
74   while (!*parser.parseOptionalInteger(intVal)) {
75     arrayOfInts.push_back(intVal);
76     if (parser.parseOptionalComma())
77       break;
78   }
79 
80   if (parser.parseRSquare() || parser.parseGreater())
81     return Attribute();
82   return get(context, widthOfSomething, oneType, arrayOfInts);
83 }
84 
85 void CompoundAAttr::print(DialectAsmPrinter &printer) const {
86   printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
87           << ", [";
88   llvm::interleaveComma(getArrayOfInts(), printer);
89   printer << "]>";
90 }
91 
92 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
94 //===----------------------------------------------------------------------===//
95 
96 Attribute TestI64ElementsAttr::parse(MLIRContext *context,
97                                      DialectAsmParser &parser, Type type) {
98   SmallVector<uint64_t> elements;
99   if (parser.parseLess() || parser.parseLSquare())
100     return Attribute();
101   uint64_t intVal;
102   while (succeeded(*parser.parseOptionalInteger(intVal))) {
103     elements.push_back(intVal);
104     if (parser.parseOptionalComma())
105       break;
106   }
107 
108   if (parser.parseRSquare() || parser.parseGreater())
109     return Attribute();
110   return parser.getChecked<TestI64ElementsAttr>(
111       context, type.cast<ShapedType>(), elements);
112 }
113 
114 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const {
115   printer << "i64_elements<[";
116   llvm::interleaveComma(getElements(), printer);
117   printer << "] : " << getType() << ">";
118 }
119 
120 LogicalResult
121 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
122                             ShapedType type, ArrayRef<uint64_t> elements) {
123   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
124     return emitError()
125            << "number of elements does not match the provided shape type, got: "
126            << elements.size() << ", but expected: " << type.getNumElements();
127   }
128   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
129     return emitError() << "expected single rank 64-bit shape type, but got: "
130                        << type;
131   return success();
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // Tablegen Generated Definitions
136 //===----------------------------------------------------------------------===//
137 
138 #include "TestAttrInterfaces.cpp.inc"
139 
140 #define GET_ATTRDEF_CLASSES
141 #include "TestAttrDefs.cpp.inc"
142 
143 //===----------------------------------------------------------------------===//
144 // TestDialect
145 //===----------------------------------------------------------------------===//
146 
/// Registers the dialect's attributes by calling `addAttributes` with the
/// template argument list supplied by the generated include below.
void TestDialect::registerAttributes() {
  addAttributes<
// With GET_ATTRDEF_LIST defined, the included file provides the template
// arguments that sit between `<` and `>`.
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
}
153 
154 Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
155                                       Type type) const {
156   StringRef attrTag;
157   if (failed(parser.parseKeyword(&attrTag)))
158     return Attribute();
159   {
160     Attribute attr;
161     auto parseResult =
162         generatedAttributeParser(getContext(), parser, attrTag, type, attr);
163     if (parseResult.hasValue())
164       return attr;
165   }
166   parser.emitError(parser.getNameLoc(), "unknown test attribute");
167   return Attribute();
168 }
169 
170 void TestDialect::printAttribute(Attribute attr,
171                                  DialectAsmPrinter &printer) const {
172   if (succeeded(generatedAttributePrinter(attr, printer)))
173     return;
174 }
175