183ef862fSRiver Riddle //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
283ef862fSRiver Riddle //
383ef862fSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
483ef862fSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
583ef862fSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
683ef862fSRiver Riddle //
783ef862fSRiver Riddle //===----------------------------------------------------------------------===//
883ef862fSRiver Riddle //
983ef862fSRiver Riddle // This file contains attributes defined by the TestDialect for testing various
1083ef862fSRiver Riddle // features of MLIR.
1183ef862fSRiver Riddle //
1283ef862fSRiver Riddle //===----------------------------------------------------------------------===//
1383ef862fSRiver Riddle 
1483ef862fSRiver Riddle #include "TestAttributes.h"
1583ef862fSRiver Riddle #include "TestDialect.h"
1683ef862fSRiver Riddle #include "mlir/IR/Builders.h"
1783ef862fSRiver Riddle #include "mlir/IR/DialectImplementation.h"
189e0b5533SMathieu Fehr #include "mlir/IR/ExtensibleDialect.h"
1983ef862fSRiver Riddle #include "mlir/IR/Types.h"
209a2fdc36SJeff Niu #include "mlir/Support/LogicalResult.h"
2183ef862fSRiver Riddle #include "llvm/ADT/Hashing.h"
2283ef862fSRiver Riddle #include "llvm/ADT/SetVector.h"
2383ef862fSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
249a2fdc36SJeff Niu #include "llvm/ADT/bit.h"
259e0b5533SMathieu Fehr #include "llvm/Support/ErrorHandling.h"
2683ef862fSRiver Riddle 
2783ef862fSRiver Riddle using namespace mlir;
287776b19eSStephen Neuendorffer using namespace test;
2983ef862fSRiver Riddle 
301447ec51SRiver Riddle //===----------------------------------------------------------------------===//
311447ec51SRiver Riddle // AttrWithTypeBuilderAttr
321447ec51SRiver Riddle //===----------------------------------------------------------------------===//
331447ec51SRiver Riddle 
parse(AsmParser & parser,Type type)34f97e72aaSMehdi Amini Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) {
351447ec51SRiver Riddle   IntegerAttr element;
361447ec51SRiver Riddle   if (parser.parseAttribute(element))
371447ec51SRiver Riddle     return Attribute();
38fb093c83SChris Lattner   return get(parser.getContext(), element);
391447ec51SRiver Riddle }
401447ec51SRiver Riddle 
print(AsmPrinter & printer) const41f97e72aaSMehdi Amini void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const {
42f30a8a6fSMehdi Amini   printer << " " << getAttr();
431447ec51SRiver Riddle }
441447ec51SRiver Riddle 
451447ec51SRiver Riddle //===----------------------------------------------------------------------===//
461447ec51SRiver Riddle // CompoundAAttr
471447ec51SRiver Riddle //===----------------------------------------------------------------------===//
481447ec51SRiver Riddle 
parse(AsmParser & parser,Type type)49f97e72aaSMehdi Amini Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
5083ef862fSRiver Riddle   int widthOfSomething;
5183ef862fSRiver Riddle   Type oneType;
5283ef862fSRiver Riddle   SmallVector<int, 4> arrayOfInts;
5383ef862fSRiver Riddle   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
5483ef862fSRiver Riddle       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
5583ef862fSRiver Riddle       parser.parseLSquare())
5683ef862fSRiver Riddle     return Attribute();
5783ef862fSRiver Riddle 
5883ef862fSRiver Riddle   int intVal;
5983ef862fSRiver Riddle   while (!*parser.parseOptionalInteger(intVal)) {
6083ef862fSRiver Riddle     arrayOfInts.push_back(intVal);
6183ef862fSRiver Riddle     if (parser.parseOptionalComma())
6283ef862fSRiver Riddle       break;
6383ef862fSRiver Riddle   }
6483ef862fSRiver Riddle 
6583ef862fSRiver Riddle   if (parser.parseRSquare() || parser.parseGreater())
6683ef862fSRiver Riddle     return Attribute();
67fb093c83SChris Lattner   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
6883ef862fSRiver Riddle }
6983ef862fSRiver Riddle 
print(AsmPrinter & printer) const70f97e72aaSMehdi Amini void CompoundAAttr::print(AsmPrinter &printer) const {
71f30a8a6fSMehdi Amini   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
7283ef862fSRiver Riddle   llvm::interleaveComma(getArrayOfInts(), printer);
7383ef862fSRiver Riddle   printer << "]>";
7483ef862fSRiver Riddle }
7583ef862fSRiver Riddle 
7683ef862fSRiver Riddle //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
78d80d3a35SRiver Riddle //===----------------------------------------------------------------------===//
79d80d3a35SRiver Riddle 
parse(AsmParser & parser,Type type)80f97e72aaSMehdi Amini Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
81d80d3a35SRiver Riddle   SmallVector<uint64_t> elements;
82d80d3a35SRiver Riddle   if (parser.parseLess() || parser.parseLSquare())
83d80d3a35SRiver Riddle     return Attribute();
84d80d3a35SRiver Riddle   uint64_t intVal;
85d80d3a35SRiver Riddle   while (succeeded(*parser.parseOptionalInteger(intVal))) {
86d80d3a35SRiver Riddle     elements.push_back(intVal);
87d80d3a35SRiver Riddle     if (parser.parseOptionalComma())
88d80d3a35SRiver Riddle       break;
89d80d3a35SRiver Riddle   }
90d80d3a35SRiver Riddle 
91d80d3a35SRiver Riddle   if (parser.parseRSquare() || parser.parseGreater())
92d80d3a35SRiver Riddle     return Attribute();
93d80d3a35SRiver Riddle   return parser.getChecked<TestI64ElementsAttr>(
94fb093c83SChris Lattner       parser.getContext(), type.cast<ShapedType>(), elements);
95d80d3a35SRiver Riddle }
96d80d3a35SRiver Riddle 
print(AsmPrinter & printer) const97f97e72aaSMehdi Amini void TestI64ElementsAttr::print(AsmPrinter &printer) const {
98f30a8a6fSMehdi Amini   printer << "<[";
99d80d3a35SRiver Riddle   llvm::interleaveComma(getElements(), printer);
100d80d3a35SRiver Riddle   printer << "] : " << getType() << ">";
101d80d3a35SRiver Riddle }
102d80d3a35SRiver Riddle 
103d80d3a35SRiver Riddle LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ShapedType type,ArrayRef<uint64_t> elements)104d80d3a35SRiver Riddle TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
105d80d3a35SRiver Riddle                             ShapedType type, ArrayRef<uint64_t> elements) {
106d80d3a35SRiver Riddle   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
107d80d3a35SRiver Riddle     return emitError()
108d80d3a35SRiver Riddle            << "number of elements does not match the provided shape type, got: "
109d80d3a35SRiver Riddle            << elements.size() << ", but expected: " << type.getNumElements();
110d80d3a35SRiver Riddle   }
111d80d3a35SRiver Riddle   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
112d80d3a35SRiver Riddle     return emitError() << "expected single rank 64-bit shape type, but got: "
113d80d3a35SRiver Riddle                        << type;
114d80d3a35SRiver Riddle   return success();
115d80d3a35SRiver Riddle }
116d80d3a35SRiver Riddle 
verify(function_ref<InFlightDiagnostic ()> emitError,int64_t one,std::string two,IntegerAttr three,ArrayRef<int> four,ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr)117f68454eeSMehdi Amini LogicalResult TestAttrWithFormatAttr::verify(
118f68454eeSMehdi Amini     function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
119f68454eeSMehdi Amini     IntegerAttr three, ArrayRef<int> four,
120f68454eeSMehdi Amini     ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr) {
1219a2fdc36SJeff Niu   if (four.size() != static_cast<unsigned>(one))
1229a2fdc36SJeff Niu     return emitError() << "expected 'one' to equal 'four.size()'";
1239a2fdc36SJeff Niu   return success();
1249a2fdc36SJeff Niu }
1259a2fdc36SJeff Niu 
1269a2fdc36SJeff Niu //===----------------------------------------------------------------------===//
1279a2fdc36SJeff Niu // Utility Functions for Generated Attributes
1289a2fdc36SJeff Niu //===----------------------------------------------------------------------===//
1299a2fdc36SJeff Niu 
parseIntArray(AsmParser & parser)130f97e72aaSMehdi Amini static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
1319a2fdc36SJeff Niu   SmallVector<int> ints;
1329a2fdc36SJeff Niu   if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
1339a2fdc36SJeff Niu         ints.push_back(0);
1349a2fdc36SJeff Niu         return parser.parseInteger(ints.back());
1359a2fdc36SJeff Niu       }) ||
1369a2fdc36SJeff Niu       parser.parseRSquare())
1379a2fdc36SJeff Niu     return failure();
1389a2fdc36SJeff Niu   return ints;
1399a2fdc36SJeff Niu }
1409a2fdc36SJeff Niu 
printIntArray(AsmPrinter & printer,ArrayRef<int> ints)141f97e72aaSMehdi Amini static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
1429a2fdc36SJeff Niu   printer << '[';
1439a2fdc36SJeff Niu   llvm::interleaveComma(ints, printer);
1449a2fdc36SJeff Niu   printer << ']';
1459a2fdc36SJeff Niu }
1469a2fdc36SJeff Niu 
147d80d3a35SRiver Riddle //===----------------------------------------------------------------------===//
14810a80c44SMarkus Böck // TestSubElementsAccessAttr
14910a80c44SMarkus Böck //===----------------------------------------------------------------------===//
15010a80c44SMarkus Böck 
parse(::mlir::AsmParser & parser,::mlir::Type type)151f97e72aaSMehdi Amini Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
15210a80c44SMarkus Böck                                            ::mlir::Type type) {
15310a80c44SMarkus Böck   Attribute first, second, third;
15410a80c44SMarkus Böck   if (parser.parseLess() || parser.parseAttribute(first) ||
15510a80c44SMarkus Böck       parser.parseComma() || parser.parseAttribute(second) ||
15610a80c44SMarkus Böck       parser.parseComma() || parser.parseAttribute(third) ||
15710a80c44SMarkus Böck       parser.parseGreater()) {
15810a80c44SMarkus Böck     return {};
15910a80c44SMarkus Böck   }
16010a80c44SMarkus Böck   return get(parser.getContext(), first, second, third);
16110a80c44SMarkus Böck }
16210a80c44SMarkus Böck 
print(::mlir::AsmPrinter & printer) const163f97e72aaSMehdi Amini void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
164f30a8a6fSMehdi Amini   printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
165f30a8a6fSMehdi Amini           << ">";
16610a80c44SMarkus Böck }
16710a80c44SMarkus Böck 
walkImmediateSubElements(llvm::function_ref<void (mlir::Attribute)> walkAttrsFn,llvm::function_ref<void (mlir::Type)> walkTypesFn) const16810a80c44SMarkus Böck void TestSubElementsAccessAttr::walkImmediateSubElements(
16910a80c44SMarkus Böck     llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
17010a80c44SMarkus Böck     llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
17110a80c44SMarkus Böck   walkAttrsFn(getFirst());
17210a80c44SMarkus Böck   walkAttrsFn(getSecond());
17310a80c44SMarkus Böck   walkAttrsFn(getThird());
17410a80c44SMarkus Böck }
17510a80c44SMarkus Böck 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const176*01eedbc7SRiver Riddle Attribute TestSubElementsAccessAttr::replaceImmediateSubElements(
177*01eedbc7SRiver Riddle     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
178*01eedbc7SRiver Riddle   assert(replAttrs.size() == 3 && "invalid number of replacement attributes");
179*01eedbc7SRiver Riddle   return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]);
18010a80c44SMarkus Böck }
18110a80c44SMarkus Böck 
18210a80c44SMarkus Böck //===----------------------------------------------------------------------===//
183ea488bd6SRiver Riddle // TestExtern1DI64ElementsAttr
184ea488bd6SRiver Riddle //===----------------------------------------------------------------------===//
185ea488bd6SRiver Riddle 
getElements() const186ea488bd6SRiver Riddle ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
187ea488bd6SRiver Riddle   return getHandle().getData()->getData();
188ea488bd6SRiver Riddle }
189ea488bd6SRiver Riddle 
190ea488bd6SRiver Riddle //===----------------------------------------------------------------------===//
19183ef862fSRiver Riddle // Tablegen Generated Definitions
19283ef862fSRiver Riddle //===----------------------------------------------------------------------===//
19383ef862fSRiver Riddle 
1949b2a1bcfSAlex Zinenko #include "TestAttrInterfaces.cpp.inc"
1959b2a1bcfSAlex Zinenko 
19683ef862fSRiver Riddle #define GET_ATTRDEF_CLASSES
19783ef862fSRiver Riddle #include "TestAttrDefs.cpp.inc"
19883ef862fSRiver Riddle 
19983ef862fSRiver Riddle //===----------------------------------------------------------------------===//
2009e0b5533SMathieu Fehr // Dynamic Attributes
2019e0b5533SMathieu Fehr //===----------------------------------------------------------------------===//
2029e0b5533SMathieu Fehr 
2039e0b5533SMathieu Fehr /// Define a singleton dynamic attribute.
2049e0b5533SMathieu Fehr static std::unique_ptr<DynamicAttrDefinition>
getDynamicSingletonAttr(TestDialect * testDialect)2059e0b5533SMathieu Fehr getDynamicSingletonAttr(TestDialect *testDialect) {
2069e0b5533SMathieu Fehr   return DynamicAttrDefinition::get(
2079e0b5533SMathieu Fehr       "dynamic_singleton", testDialect,
2089e0b5533SMathieu Fehr       [](function_ref<InFlightDiagnostic()> emitError,
2099e0b5533SMathieu Fehr          ArrayRef<Attribute> args) {
2109e0b5533SMathieu Fehr         if (!args.empty()) {
2119e0b5533SMathieu Fehr           emitError() << "expected 0 attribute arguments, but had "
2129e0b5533SMathieu Fehr                       << args.size();
2139e0b5533SMathieu Fehr           return failure();
2149e0b5533SMathieu Fehr         }
2159e0b5533SMathieu Fehr         return success();
2169e0b5533SMathieu Fehr       });
2179e0b5533SMathieu Fehr }
2189e0b5533SMathieu Fehr 
2199e0b5533SMathieu Fehr /// Define a dynamic attribute representing a pair or attributes.
2209e0b5533SMathieu Fehr static std::unique_ptr<DynamicAttrDefinition>
getDynamicPairAttr(TestDialect * testDialect)2219e0b5533SMathieu Fehr getDynamicPairAttr(TestDialect *testDialect) {
2229e0b5533SMathieu Fehr   return DynamicAttrDefinition::get(
2239e0b5533SMathieu Fehr       "dynamic_pair", testDialect,
2249e0b5533SMathieu Fehr       [](function_ref<InFlightDiagnostic()> emitError,
2259e0b5533SMathieu Fehr          ArrayRef<Attribute> args) {
2269e0b5533SMathieu Fehr         if (args.size() != 2) {
2279e0b5533SMathieu Fehr           emitError() << "expected 2 attribute arguments, but had "
2289e0b5533SMathieu Fehr                       << args.size();
2299e0b5533SMathieu Fehr           return failure();
2309e0b5533SMathieu Fehr         }
2319e0b5533SMathieu Fehr         return success();
2329e0b5533SMathieu Fehr       });
2339e0b5533SMathieu Fehr }
2349e0b5533SMathieu Fehr 
2359e0b5533SMathieu Fehr static std::unique_ptr<DynamicAttrDefinition>
getDynamicCustomAssemblyFormatAttr(TestDialect * testDialect)2369e0b5533SMathieu Fehr getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
2379e0b5533SMathieu Fehr   auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
2389e0b5533SMathieu Fehr                      ArrayRef<Attribute> args) {
2399e0b5533SMathieu Fehr     if (args.size() != 2) {
2409e0b5533SMathieu Fehr       emitError() << "expected 2 attribute arguments, but had " << args.size();
2419e0b5533SMathieu Fehr       return failure();
2429e0b5533SMathieu Fehr     }
2439e0b5533SMathieu Fehr     return success();
2449e0b5533SMathieu Fehr   };
2459e0b5533SMathieu Fehr 
2469e0b5533SMathieu Fehr   auto parser = [](AsmParser &parser,
2479e0b5533SMathieu Fehr                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
2489e0b5533SMathieu Fehr     Attribute leftAttr, rightAttr;
2499e0b5533SMathieu Fehr     if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
2509e0b5533SMathieu Fehr         parser.parseColon() || parser.parseAttribute(rightAttr) ||
2519e0b5533SMathieu Fehr         parser.parseGreater())
2529e0b5533SMathieu Fehr       return failure();
2539e0b5533SMathieu Fehr     parsedParams.push_back(leftAttr);
2549e0b5533SMathieu Fehr     parsedParams.push_back(rightAttr);
2559e0b5533SMathieu Fehr     return success();
2569e0b5533SMathieu Fehr   };
2579e0b5533SMathieu Fehr 
2589e0b5533SMathieu Fehr   auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
2599e0b5533SMathieu Fehr     printer << "<" << params[0] << ":" << params[1] << ">";
2609e0b5533SMathieu Fehr   };
2619e0b5533SMathieu Fehr 
2629e0b5533SMathieu Fehr   return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
2639e0b5533SMathieu Fehr                                     testDialect, std::move(verifier),
2649e0b5533SMathieu Fehr                                     std::move(parser), std::move(printer));
2659e0b5533SMathieu Fehr }
2669e0b5533SMathieu Fehr 
2679e0b5533SMathieu Fehr //===----------------------------------------------------------------------===//
26883ef862fSRiver Riddle // TestDialect
26983ef862fSRiver Riddle //===----------------------------------------------------------------------===//
27083ef862fSRiver Riddle 
registerAttributes()27131bb8efdSRiver Riddle void TestDialect::registerAttributes() {
27231bb8efdSRiver Riddle   addAttributes<
27331bb8efdSRiver Riddle #define GET_ATTRDEF_LIST
27431bb8efdSRiver Riddle #include "TestAttrDefs.cpp.inc"
27531bb8efdSRiver Riddle       >();
2769e0b5533SMathieu Fehr   registerDynamicAttr(getDynamicSingletonAttr(this));
2779e0b5533SMathieu Fehr   registerDynamicAttr(getDynamicPairAttr(this));
2789e0b5533SMathieu Fehr   registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
27931bb8efdSRiver Riddle }
280