1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(
21     Operation *op, StringRef name, unsigned maxNumElements, ArrayAttr attr,
22     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
23   /// Check static and dynamic offsets/sizes/strides does not overflow type.
24   if (attr.size() > maxNumElements)
25     return op->emitError("expected <= ")
26            << maxNumElements << " " << name << " values";
27   unsigned expectedNumDynamicEntries =
28       llvm::count_if(attr.getValue(), [&](Attribute attr) {
29         return isDynamic(attr.cast<IntegerAttr>().getInt());
30       });
31   if (values.size() != expectedNumDynamicEntries)
32     return op->emitError("expected ")
33            << expectedNumDynamicEntries << " dynamic " << name << " values";
34   return success();
35 }
36 
37 LogicalResult
38 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
39   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
40   // Offsets can come in 2 flavors:
41   //   1. Either single entry (when maxRanks == 1).
42   //   2. Or as an array whose rank must match that of the mixed sizes.
43   // So that the result type is well-formed.
44   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) &&
45       op.getMixedOffsets().size() != op.getMixedSizes().size())
46     return op->emitError(
47                "expected mixed offsets rank to match mixed sizes rank (")
48            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
49            << ") so the rank of the result type is well-formed.";
50   // Ranks of mixed sizes and strides must always match so the result type is
51   // well-formed.
52   if (op.getMixedSizes().size() != op.getMixedStrides().size())
53     return op->emitError(
54                "expected mixed sizes rank to match mixed strides rank (")
55            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
56            << ") so the rank of the result type is well-formed.";
57 
58   if (failed(verifyListOfOperandsOrIntegers(
59           op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
60           ShapedType::isDynamicStrideOrOffset)))
61     return failure();
62   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
63                                             op.static_sizes(), op.sizes(),
64                                             ShapedType::isDynamic)))
65     return failure();
66   if (failed(verifyListOfOperandsOrIntegers(
67           op, "stride", maxRanks[2], op.static_strides(), op.strides(),
68           ShapedType::isDynamicStrideOrOffset)))
69     return failure();
70   return success();
71 }
72 
73 template <int64_t dynVal>
74 static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
75                                             ArrayAttr arrayAttr) {
76   p << '[';
77   if (arrayAttr.empty()) {
78     p << "]";
79     return;
80   }
81   unsigned idx = 0;
82   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
83     int64_t val = a.cast<IntegerAttr>().getInt();
84     if (val == dynVal)
85       p << values[idx++];
86     else
87       p << val;
88   });
89   p << ']';
90 }
91 
92 void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
93                                                        Operation *op,
94                                                        OperandRange values,
95                                                        ArrayAttr integers) {
96   return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
97       p, values, integers);
98 }
99 
100 void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
101                                             OperandRange values,
102                                             ArrayAttr integers) {
103   return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
104                                                                    integers);
105 }
106 
107 template <int64_t dynVal>
108 static ParseResult
109 parseOperandsOrIntegersImpl(OpAsmParser &parser,
110                             SmallVectorImpl<OpAsmParser::OperandType> &values,
111                             ArrayAttr &integers) {
112   if (failed(parser.parseLSquare()))
113     return failure();
114   // 0-D.
115   if (succeeded(parser.parseOptionalRSquare())) {
116     integers = parser.getBuilder().getArrayAttr({});
117     return success();
118   }
119 
120   SmallVector<int64_t, 4> attrVals;
121   while (true) {
122     OpAsmParser::OperandType operand;
123     auto res = parser.parseOptionalOperand(operand);
124     if (res.hasValue() && succeeded(res.getValue())) {
125       values.push_back(operand);
126       attrVals.push_back(dynVal);
127     } else {
128       IntegerAttr attr;
129       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
130         return parser.emitError(parser.getNameLoc())
131                << "expected SSA value or integer";
132       attrVals.push_back(attr.getInt());
133     }
134 
135     if (succeeded(parser.parseOptionalComma()))
136       continue;
137     if (failed(parser.parseRSquare()))
138       return failure();
139     break;
140   }
141   integers = parser.getBuilder().getI64ArrayAttr(attrVals);
142   return success();
143 }
144 
145 ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
146     OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
147     ArrayAttr &integers) {
148   return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
149       parser, values, integers);
150 }
151 
152 ParseResult mlir::parseOperandsOrIntegersSizesList(
153     OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
154     ArrayAttr &integers) {
155   return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
156                                                                integers);
157 }
158 
159 bool mlir::detail::sameOffsetsSizesAndStrides(
160     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
161     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
162   if (a.static_offsets().size() != b.static_offsets().size())
163     return false;
164   if (a.static_sizes().size() != b.static_sizes().size())
165     return false;
166   if (a.static_strides().size() != b.static_strides().size())
167     return false;
168   for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
169     if (!cmp(std::get<0>(it), std::get<1>(it)))
170       return false;
171   for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
172     if (!cmp(std::get<0>(it), std::get<1>(it)))
173       return false;
174   for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
175     if (!cmp(std::get<0>(it), std::get<1>(it)))
176       return false;
177   return true;
178 }
179