1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(
21     Operation *op, StringRef name, unsigned maxNumElements, ArrayAttr attr,
22     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
23   /// Check static and dynamic offsets/sizes/strides does not overflow type.
24   if (attr.size() > maxNumElements)
25     return op->emitError("expected <= ")
26            << maxNumElements << " " << name << " values";
27   unsigned expectedNumDynamicEntries =
28       llvm::count_if(attr.getValue(), [&](Attribute attr) {
29         return isDynamic(attr.cast<IntegerAttr>().getInt());
30       });
31   if (values.size() != expectedNumDynamicEntries)
32     return op->emitError("expected ")
33            << expectedNumDynamicEntries << " dynamic " << name << " values";
34   return success();
35 }
36 
37 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
38   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
39   // Offsets can come in 2 flavors:
40   //   1. Either single entry (when maxRanks == 1).
41   //   2. Or as an array whose rank must match that of the mixed sizes.
42   // So that the result type is well-formed.
43   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) &&
44       op.getMixedOffsets().size() != op.getMixedSizes().size())
45     return op->emitError(
46                "expected mixed offsets rank to match mixed sizes rank (")
47            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
48            << ") so the rank of the result type is well-formed.";
49   // Ranks of mixed sizes and strides must always match so the result type is
50   // well-formed.
51   if (op.getMixedSizes().size() != op.getMixedStrides().size())
52     return op->emitError(
53                "expected mixed sizes rank to match mixed strides rank (")
54            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
55            << ") so the rank of the result type is well-formed.";
56 
57   if (failed(verifyListOfOperandsOrIntegers(
58           op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
59           ShapedType::isDynamicStrideOrOffset)))
60     return failure();
61   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
62                                             op.static_sizes(), op.sizes(),
63                                             ShapedType::isDynamic)))
64     return failure();
65   if (failed(verifyListOfOperandsOrIntegers(
66           op, "stride", maxRanks[2], op.static_strides(), op.strides(),
67           ShapedType::isDynamicStrideOrOffset)))
68     return failure();
69   return success();
70 }
71 
72 template <int64_t dynVal>
73 static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
74                                             ArrayAttr arrayAttr) {
75   p << '[';
76   if (arrayAttr.empty()) {
77     p << "]";
78     return;
79   }
80   unsigned idx = 0;
81   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
82     int64_t val = a.cast<IntegerAttr>().getInt();
83     if (val == dynVal)
84       p << values[idx++];
85     else
86       p << val;
87   });
88   p << ']';
89 }
90 
91 void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
92                                                        Operation *op,
93                                                        OperandRange values,
94                                                        ArrayAttr integers) {
95   return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
96       p, values, integers);
97 }
98 
99 void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
100                                             OperandRange values,
101                                             ArrayAttr integers) {
102   return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
103                                                                    integers);
104 }
105 
106 template <int64_t dynVal>
107 static ParseResult
108 parseOperandsOrIntegersImpl(OpAsmParser &parser,
109                             SmallVectorImpl<OpAsmParser::OperandType> &values,
110                             ArrayAttr &integers) {
111   if (failed(parser.parseLSquare()))
112     return failure();
113   // 0-D.
114   if (succeeded(parser.parseOptionalRSquare())) {
115     integers = parser.getBuilder().getArrayAttr({});
116     return success();
117   }
118 
119   SmallVector<int64_t, 4> attrVals;
120   while (true) {
121     OpAsmParser::OperandType operand;
122     auto res = parser.parseOptionalOperand(operand);
123     if (res.hasValue() && succeeded(res.getValue())) {
124       values.push_back(operand);
125       attrVals.push_back(dynVal);
126     } else {
127       IntegerAttr attr;
128       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
129         return parser.emitError(parser.getNameLoc())
130                << "expected SSA value or integer";
131       attrVals.push_back(attr.getInt());
132     }
133 
134     if (succeeded(parser.parseOptionalComma()))
135       continue;
136     if (failed(parser.parseRSquare()))
137       return failure();
138     break;
139   }
140   integers = parser.getBuilder().getI64ArrayAttr(attrVals);
141   return success();
142 }
143 
144 ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
145     OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
146     ArrayAttr &integers) {
147   return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
148       parser, values, integers);
149 }
150 
151 ParseResult mlir::parseOperandsOrIntegersSizesList(
152     OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
153     ArrayAttr &integers) {
154   return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
155                                                                integers);
156 }
157