1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10
11 using namespace mlir;
12
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19
verifyListOfOperandsOrIntegers(Operation * op,StringRef name,unsigned numElements,ArrayAttr attr,ValueRange values,llvm::function_ref<bool (int64_t)> isDynamic)20 LogicalResult mlir::verifyListOfOperandsOrIntegers(
21 Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
22 ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
23 /// Check static and dynamic offsets/sizes/strides does not overflow type.
24 if (attr.size() != numElements)
25 return op->emitError("expected ")
26 << numElements << " " << name << " values";
27 unsigned expectedNumDynamicEntries =
28 llvm::count_if(attr.getValue(), [&](Attribute attr) {
29 return isDynamic(attr.cast<IntegerAttr>().getInt());
30 });
31 if (values.size() != expectedNumDynamicEntries)
32 return op->emitError("expected ")
33 << expectedNumDynamicEntries << " dynamic " << name << " values";
34 return success();
35 }
36
37 LogicalResult
verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)38 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
39 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
40 // Offsets can come in 2 flavors:
41 // 1. Either single entry (when maxRanks == 1).
42 // 2. Or as an array whose rank must match that of the mixed sizes.
43 // So that the result type is well-formed.
44 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
45 op.getMixedOffsets().size() != op.getMixedSizes().size())
46 return op->emitError(
47 "expected mixed offsets rank to match mixed sizes rank (")
48 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
49 << ") so the rank of the result type is well-formed.";
50 // Ranks of mixed sizes and strides must always match so the result type is
51 // well-formed.
52 if (op.getMixedSizes().size() != op.getMixedStrides().size())
53 return op->emitError(
54 "expected mixed sizes rank to match mixed strides rank (")
55 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
56 << ") so the rank of the result type is well-formed.";
57
58 if (failed(verifyListOfOperandsOrIntegers(
59 op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
60 ShapedType::isDynamicStrideOrOffset)))
61 return failure();
62 if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
63 op.static_sizes(), op.sizes(),
64 ShapedType::isDynamic)))
65 return failure();
66 if (failed(verifyListOfOperandsOrIntegers(
67 op, "stride", maxRanks[2], op.static_strides(), op.strides(),
68 ShapedType::isDynamicStrideOrOffset)))
69 return failure();
70 return success();
71 }
72
73 template <int64_t dynVal>
printOperandsOrIntegersListImpl(OpAsmPrinter & p,ValueRange values,ArrayAttr arrayAttr)74 static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
75 ArrayAttr arrayAttr) {
76 p << '[';
77 if (arrayAttr.empty()) {
78 p << "]";
79 return;
80 }
81 unsigned idx = 0;
82 llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
83 int64_t val = a.cast<IntegerAttr>().getInt();
84 if (val == dynVal)
85 p << values[idx++];
86 else
87 p << val;
88 });
89 p << ']';
90 }
91
printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter & p,Operation * op,OperandRange values,ArrayAttr integers)92 void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
93 Operation *op,
94 OperandRange values,
95 ArrayAttr integers) {
96 return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
97 p, values, integers);
98 }
99
printOperandsOrIntegersSizesList(OpAsmPrinter & p,Operation * op,OperandRange values,ArrayAttr integers)100 void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
101 OperandRange values,
102 ArrayAttr integers) {
103 return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
104 integers);
105 }
106
107 template <int64_t dynVal>
parseOperandsOrIntegersImpl(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)108 static ParseResult parseOperandsOrIntegersImpl(
109 OpAsmParser &parser,
110 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
111 ArrayAttr &integers) {
112 if (failed(parser.parseLSquare()))
113 return failure();
114 // 0-D.
115 if (succeeded(parser.parseOptionalRSquare())) {
116 integers = parser.getBuilder().getArrayAttr({});
117 return success();
118 }
119
120 SmallVector<int64_t, 4> attrVals;
121 while (true) {
122 OpAsmParser::UnresolvedOperand operand;
123 auto res = parser.parseOptionalOperand(operand);
124 if (res.hasValue() && succeeded(res.getValue())) {
125 values.push_back(operand);
126 attrVals.push_back(dynVal);
127 } else {
128 IntegerAttr attr;
129 if (failed(parser.parseAttribute<IntegerAttr>(attr)))
130 return parser.emitError(parser.getNameLoc())
131 << "expected SSA value or integer";
132 attrVals.push_back(attr.getInt());
133 }
134
135 if (succeeded(parser.parseOptionalComma()))
136 continue;
137 if (failed(parser.parseRSquare()))
138 return failure();
139 break;
140 }
141 integers = parser.getBuilder().getI64ArrayAttr(attrVals);
142 return success();
143 }
144
parseOperandsOrIntegersOffsetsOrStridesList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)145 ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
146 OpAsmParser &parser,
147 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
148 ArrayAttr &integers) {
149 return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
150 parser, values, integers);
151 }
152
parseOperandsOrIntegersSizesList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)153 ParseResult mlir::parseOperandsOrIntegersSizesList(
154 OpAsmParser &parser,
155 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
156 ArrayAttr &integers) {
157 return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
158 integers);
159 }
160
sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a,OffsetSizeAndStrideOpInterface b,llvm::function_ref<bool (OpFoldResult,OpFoldResult)> cmp)161 bool mlir::detail::sameOffsetsSizesAndStrides(
162 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
163 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
164 if (a.static_offsets().size() != b.static_offsets().size())
165 return false;
166 if (a.static_sizes().size() != b.static_sizes().size())
167 return false;
168 if (a.static_strides().size() != b.static_strides().size())
169 return false;
170 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
171 if (!cmp(std::get<0>(it), std::get<1>(it)))
172 return false;
173 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
174 if (!cmp(std::get<0>(it), std::get<1>(it)))
175 return false;
176 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
177 if (!cmp(std::get<0>(it), std::get<1>(it)))
178 return false;
179 return true;
180 }
181
182 SmallVector<OpFoldResult, 4>
getMixedOffsets(OffsetSizeAndStrideOpInterface op,ArrayAttr staticOffsets,ValueRange offsets)183 mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
184 ArrayAttr staticOffsets, ValueRange offsets) {
185 SmallVector<OpFoldResult, 4> res;
186 unsigned numDynamic = 0;
187 unsigned count = static_cast<unsigned>(staticOffsets.size());
188 for (unsigned idx = 0; idx < count; ++idx) {
189 if (op.isDynamicOffset(idx))
190 res.push_back(offsets[numDynamic++]);
191 else
192 res.push_back(staticOffsets[idx]);
193 }
194 return res;
195 }
196
197 SmallVector<OpFoldResult, 4>
getMixedSizes(OffsetSizeAndStrideOpInterface op,ArrayAttr staticSizes,ValueRange sizes)198 mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
199 ValueRange sizes) {
200 SmallVector<OpFoldResult, 4> res;
201 unsigned numDynamic = 0;
202 unsigned count = static_cast<unsigned>(staticSizes.size());
203 for (unsigned idx = 0; idx < count; ++idx) {
204 if (op.isDynamicSize(idx))
205 res.push_back(sizes[numDynamic++]);
206 else
207 res.push_back(staticSizes[idx]);
208 }
209 return res;
210 }
211
212 SmallVector<OpFoldResult, 4>
getMixedStrides(OffsetSizeAndStrideOpInterface op,ArrayAttr staticStrides,ValueRange strides)213 mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
214 ArrayAttr staticStrides, ValueRange strides) {
215 SmallVector<OpFoldResult, 4> res;
216 unsigned numDynamic = 0;
217 unsigned count = static_cast<unsigned>(staticStrides.size());
218 for (unsigned idx = 0; idx < count; ++idx) {
219 if (op.isDynamicStride(idx))
220 res.push_back(strides[numDynamic++]);
221 else
222 res.push_back(staticStrides[idx]);
223 }
224 return res;
225 }
226
227 static std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedImpl(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues,const int64_t dynamicValuePlaceholder)228 decomposeMixedImpl(OpBuilder &b,
229 const SmallVectorImpl<OpFoldResult> &mixedValues,
230 const int64_t dynamicValuePlaceholder) {
231 SmallVector<int64_t> staticValues;
232 SmallVector<Value> dynamicValues;
233 for (const auto &it : mixedValues) {
234 if (it.is<Attribute>()) {
235 staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
236 } else {
237 staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
238 dynamicValues.push_back(it.get<Value>());
239 }
240 }
241 return {b.getI64ArrayAttr(staticValues), dynamicValues};
242 }
243
decomposeMixedStridesOrOffsets(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues)244 std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
245 OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
246 return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
247 }
248
249 std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedSizes(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues)250 mlir::decomposeMixedSizes(OpBuilder &b,
251 const SmallVectorImpl<OpFoldResult> &mixedValues) {
252 return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
253 }
254