12458cd27SLei Zhang //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
22458cd27SLei Zhang //
32458cd27SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42458cd27SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
52458cd27SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62458cd27SLei Zhang //
72458cd27SLei Zhang //===----------------------------------------------------------------------===//
82458cd27SLei Zhang
92458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.h"
102458cd27SLei Zhang
112458cd27SLei Zhang using namespace mlir;
122458cd27SLei Zhang
132458cd27SLei Zhang //===----------------------------------------------------------------------===//
142458cd27SLei Zhang // ViewLike Interfaces
152458cd27SLei Zhang //===----------------------------------------------------------------------===//
162458cd27SLei Zhang
/// Include the definitions of the view-like interfaces.
182458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19a8de412fSNicolas Vasilache
verifyListOfOperandsOrIntegers(Operation * op,StringRef name,unsigned numElements,ArrayAttr attr,ValueRange values,llvm::function_ref<bool (int64_t)> isDynamic)20118a7156SMaheshRavishankar LogicalResult mlir::verifyListOfOperandsOrIntegers(
217df7586aSMaheshRavishankar Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
22118a7156SMaheshRavishankar ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
235133673dSNicolas Vasilache /// Check static and dynamic offsets/sizes/strides does not overflow type.
247df7586aSMaheshRavishankar if (attr.size() != numElements)
257df7586aSMaheshRavishankar return op->emitError("expected ")
267df7586aSMaheshRavishankar << numElements << " " << name << " values";
27a8de412fSNicolas Vasilache unsigned expectedNumDynamicEntries =
28a8de412fSNicolas Vasilache llvm::count_if(attr.getValue(), [&](Attribute attr) {
29a8de412fSNicolas Vasilache return isDynamic(attr.cast<IntegerAttr>().getInt());
30a8de412fSNicolas Vasilache });
31a8de412fSNicolas Vasilache if (values.size() != expectedNumDynamicEntries)
32118a7156SMaheshRavishankar return op->emitError("expected ")
33a8de412fSNicolas Vasilache << expectedNumDynamicEntries << " dynamic " << name << " values";
34a8de412fSNicolas Vasilache return success();
35a8de412fSNicolas Vasilache }
36a8de412fSNicolas Vasilache
3762851ea7SUday Bondhugula LogicalResult
verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)3862851ea7SUday Bondhugula mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
395133673dSNicolas Vasilache std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
405133673dSNicolas Vasilache // Offsets can come in 2 flavors:
415133673dSNicolas Vasilache // 1. Either single entry (when maxRanks == 1).
425133673dSNicolas Vasilache // 2. Or as an array whose rank must match that of the mixed sizes.
435133673dSNicolas Vasilache // So that the result type is well-formed.
44*84124ff8SMehdi Amini if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
455133673dSNicolas Vasilache op.getMixedOffsets().size() != op.getMixedSizes().size())
465133673dSNicolas Vasilache return op->emitError(
475133673dSNicolas Vasilache "expected mixed offsets rank to match mixed sizes rank (")
485133673dSNicolas Vasilache << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
495133673dSNicolas Vasilache << ") so the rank of the result type is well-formed.";
505133673dSNicolas Vasilache // Ranks of mixed sizes and strides must always match so the result type is
515133673dSNicolas Vasilache // well-formed.
525133673dSNicolas Vasilache if (op.getMixedSizes().size() != op.getMixedStrides().size())
535133673dSNicolas Vasilache return op->emitError(
545133673dSNicolas Vasilache "expected mixed sizes rank to match mixed strides rank (")
555133673dSNicolas Vasilache << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
565133673dSNicolas Vasilache << ") so the rank of the result type is well-formed.";
575133673dSNicolas Vasilache
58118a7156SMaheshRavishankar if (failed(verifyListOfOperandsOrIntegers(
595133673dSNicolas Vasilache op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
60118a7156SMaheshRavishankar ShapedType::isDynamicStrideOrOffset)))
61a8de412fSNicolas Vasilache return failure();
625133673dSNicolas Vasilache if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
63118a7156SMaheshRavishankar op.static_sizes(), op.sizes(),
64118a7156SMaheshRavishankar ShapedType::isDynamic)))
65a8de412fSNicolas Vasilache return failure();
66118a7156SMaheshRavishankar if (failed(verifyListOfOperandsOrIntegers(
675133673dSNicolas Vasilache op, "stride", maxRanks[2], op.static_strides(), op.strides(),
68118a7156SMaheshRavishankar ShapedType::isDynamicStrideOrOffset)))
69a8de412fSNicolas Vasilache return failure();
70a8de412fSNicolas Vasilache return success();
71a8de412fSNicolas Vasilache }
72b6c71c13SNicolas Vasilache
73342d4662SMaheshRavishankar template <int64_t dynVal>
printOperandsOrIntegersListImpl(OpAsmPrinter & p,ValueRange values,ArrayAttr arrayAttr)74342d4662SMaheshRavishankar static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
75342d4662SMaheshRavishankar ArrayAttr arrayAttr) {
76c2470810SNicolas Vasilache p << '[';
77342d4662SMaheshRavishankar if (arrayAttr.empty()) {
78342d4662SMaheshRavishankar p << "]";
79342d4662SMaheshRavishankar return;
80342d4662SMaheshRavishankar }
81c2470810SNicolas Vasilache unsigned idx = 0;
82c2470810SNicolas Vasilache llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
83c2470810SNicolas Vasilache int64_t val = a.cast<IntegerAttr>().getInt();
84342d4662SMaheshRavishankar if (val == dynVal)
85c2470810SNicolas Vasilache p << values[idx++];
86c2470810SNicolas Vasilache else
87c2470810SNicolas Vasilache p << val;
88c2470810SNicolas Vasilache });
89c2470810SNicolas Vasilache p << ']';
90c2470810SNicolas Vasilache }
91c2470810SNicolas Vasilache
printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter & p,Operation * op,OperandRange values,ArrayAttr integers)92342d4662SMaheshRavishankar void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
93342d4662SMaheshRavishankar Operation *op,
94342d4662SMaheshRavishankar OperandRange values,
95342d4662SMaheshRavishankar ArrayAttr integers) {
96342d4662SMaheshRavishankar return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
97342d4662SMaheshRavishankar p, values, integers);
98c2470810SNicolas Vasilache }
99c2470810SNicolas Vasilache
printOperandsOrIntegersSizesList(OpAsmPrinter & p,Operation * op,OperandRange values,ArrayAttr integers)100342d4662SMaheshRavishankar void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
101342d4662SMaheshRavishankar OperandRange values,
102342d4662SMaheshRavishankar ArrayAttr integers) {
103342d4662SMaheshRavishankar return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
104342d4662SMaheshRavishankar integers);
105342d4662SMaheshRavishankar }
106342d4662SMaheshRavishankar
107342d4662SMaheshRavishankar template <int64_t dynVal>
parseOperandsOrIntegersImpl(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)108e13d23bcSMarkus Böck static ParseResult parseOperandsOrIntegersImpl(
109e13d23bcSMarkus Böck OpAsmParser &parser,
110e13d23bcSMarkus Böck SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
111342d4662SMaheshRavishankar ArrayAttr &integers) {
112b6c71c13SNicolas Vasilache if (failed(parser.parseLSquare()))
113b6c71c13SNicolas Vasilache return failure();
114b6c71c13SNicolas Vasilache // 0-D.
115b6c71c13SNicolas Vasilache if (succeeded(parser.parseOptionalRSquare())) {
116342d4662SMaheshRavishankar integers = parser.getBuilder().getArrayAttr({});
117b6c71c13SNicolas Vasilache return success();
118b6c71c13SNicolas Vasilache }
119b6c71c13SNicolas Vasilache
120b6c71c13SNicolas Vasilache SmallVector<int64_t, 4> attrVals;
121b6c71c13SNicolas Vasilache while (true) {
122e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand operand;
123b6c71c13SNicolas Vasilache auto res = parser.parseOptionalOperand(operand);
1243b7c3a65SKazu Hirata if (res.hasValue() && succeeded(res.getValue())) {
125342d4662SMaheshRavishankar values.push_back(operand);
126b6c71c13SNicolas Vasilache attrVals.push_back(dynVal);
127b6c71c13SNicolas Vasilache } else {
128b6c71c13SNicolas Vasilache IntegerAttr attr;
129b6c71c13SNicolas Vasilache if (failed(parser.parseAttribute<IntegerAttr>(attr)))
130b6c71c13SNicolas Vasilache return parser.emitError(parser.getNameLoc())
131b6c71c13SNicolas Vasilache << "expected SSA value or integer";
132b6c71c13SNicolas Vasilache attrVals.push_back(attr.getInt());
133b6c71c13SNicolas Vasilache }
134b6c71c13SNicolas Vasilache
135b6c71c13SNicolas Vasilache if (succeeded(parser.parseOptionalComma()))
136b6c71c13SNicolas Vasilache continue;
137b6c71c13SNicolas Vasilache if (failed(parser.parseRSquare()))
138b6c71c13SNicolas Vasilache return failure();
139b6c71c13SNicolas Vasilache break;
140b6c71c13SNicolas Vasilache }
141342d4662SMaheshRavishankar integers = parser.getBuilder().getI64ArrayAttr(attrVals);
142b6c71c13SNicolas Vasilache return success();
143b6c71c13SNicolas Vasilache }
144b6c71c13SNicolas Vasilache
parseOperandsOrIntegersOffsetsOrStridesList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)145342d4662SMaheshRavishankar ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
146e13d23bcSMarkus Böck OpAsmParser &parser,
147e13d23bcSMarkus Böck SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
148342d4662SMaheshRavishankar ArrayAttr &integers) {
149342d4662SMaheshRavishankar return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
150342d4662SMaheshRavishankar parser, values, integers);
151c2470810SNicolas Vasilache }
152c2470810SNicolas Vasilache
parseOperandsOrIntegersSizesList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)153342d4662SMaheshRavishankar ParseResult mlir::parseOperandsOrIntegersSizesList(
154e13d23bcSMarkus Böck OpAsmParser &parser,
155e13d23bcSMarkus Böck SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
156342d4662SMaheshRavishankar ArrayAttr &integers) {
157342d4662SMaheshRavishankar return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
158342d4662SMaheshRavishankar integers);
159b6c71c13SNicolas Vasilache }
160ce4f99e7SNicolas Vasilache
sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a,OffsetSizeAndStrideOpInterface b,llvm::function_ref<bool (OpFoldResult,OpFoldResult)> cmp)161ce4f99e7SNicolas Vasilache bool mlir::detail::sameOffsetsSizesAndStrides(
162ce4f99e7SNicolas Vasilache OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
163ce4f99e7SNicolas Vasilache llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
164ce4f99e7SNicolas Vasilache if (a.static_offsets().size() != b.static_offsets().size())
165ce4f99e7SNicolas Vasilache return false;
166ce4f99e7SNicolas Vasilache if (a.static_sizes().size() != b.static_sizes().size())
167ce4f99e7SNicolas Vasilache return false;
168ce4f99e7SNicolas Vasilache if (a.static_strides().size() != b.static_strides().size())
169ce4f99e7SNicolas Vasilache return false;
170ce4f99e7SNicolas Vasilache for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
171ce4f99e7SNicolas Vasilache if (!cmp(std::get<0>(it), std::get<1>(it)))
172ce4f99e7SNicolas Vasilache return false;
173ce4f99e7SNicolas Vasilache for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
174ce4f99e7SNicolas Vasilache if (!cmp(std::get<0>(it), std::get<1>(it)))
175ce4f99e7SNicolas Vasilache return false;
176ce4f99e7SNicolas Vasilache for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
177ce4f99e7SNicolas Vasilache if (!cmp(std::get<0>(it), std::get<1>(it)))
178ce4f99e7SNicolas Vasilache return false;
179ce4f99e7SNicolas Vasilache return true;
180ce4f99e7SNicolas Vasilache }
181d0ee094dSNicolas Vasilache
182b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4>
getMixedOffsets(OffsetSizeAndStrideOpInterface op,ArrayAttr staticOffsets,ValueRange offsets)183b65f21a4SIvan Butygin mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
184b65f21a4SIvan Butygin ArrayAttr staticOffsets, ValueRange offsets) {
185b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4> res;
186b65f21a4SIvan Butygin unsigned numDynamic = 0;
187b65f21a4SIvan Butygin unsigned count = static_cast<unsigned>(staticOffsets.size());
188b65f21a4SIvan Butygin for (unsigned idx = 0; idx < count; ++idx) {
189b65f21a4SIvan Butygin if (op.isDynamicOffset(idx))
190b65f21a4SIvan Butygin res.push_back(offsets[numDynamic++]);
191b65f21a4SIvan Butygin else
192b65f21a4SIvan Butygin res.push_back(staticOffsets[idx]);
193b65f21a4SIvan Butygin }
194b65f21a4SIvan Butygin return res;
195b65f21a4SIvan Butygin }
196b65f21a4SIvan Butygin
197b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4>
getMixedSizes(OffsetSizeAndStrideOpInterface op,ArrayAttr staticSizes,ValueRange sizes)198b65f21a4SIvan Butygin mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
199b65f21a4SIvan Butygin ValueRange sizes) {
200b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4> res;
201b65f21a4SIvan Butygin unsigned numDynamic = 0;
202b65f21a4SIvan Butygin unsigned count = static_cast<unsigned>(staticSizes.size());
203b65f21a4SIvan Butygin for (unsigned idx = 0; idx < count; ++idx) {
204b65f21a4SIvan Butygin if (op.isDynamicSize(idx))
205b65f21a4SIvan Butygin res.push_back(sizes[numDynamic++]);
206b65f21a4SIvan Butygin else
207b65f21a4SIvan Butygin res.push_back(staticSizes[idx]);
208b65f21a4SIvan Butygin }
209b65f21a4SIvan Butygin return res;
210b65f21a4SIvan Butygin }
211b65f21a4SIvan Butygin
212b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4>
getMixedStrides(OffsetSizeAndStrideOpInterface op,ArrayAttr staticStrides,ValueRange strides)213b65f21a4SIvan Butygin mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
214b65f21a4SIvan Butygin ArrayAttr staticStrides, ValueRange strides) {
215b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4> res;
216b65f21a4SIvan Butygin unsigned numDynamic = 0;
217b65f21a4SIvan Butygin unsigned count = static_cast<unsigned>(staticStrides.size());
218b65f21a4SIvan Butygin for (unsigned idx = 0; idx < count; ++idx) {
219b65f21a4SIvan Butygin if (op.isDynamicStride(idx))
220b65f21a4SIvan Butygin res.push_back(strides[numDynamic++]);
221b65f21a4SIvan Butygin else
222b65f21a4SIvan Butygin res.push_back(staticStrides[idx]);
223b65f21a4SIvan Butygin }
224b65f21a4SIvan Butygin return res;
225b65f21a4SIvan Butygin }
22602d29afdSFrederik Gossen
22702d29afdSFrederik Gossen static std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedImpl(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues,const int64_t dynamicValuePlaceholder)22802d29afdSFrederik Gossen decomposeMixedImpl(OpBuilder &b,
22902d29afdSFrederik Gossen const SmallVectorImpl<OpFoldResult> &mixedValues,
23002d29afdSFrederik Gossen const int64_t dynamicValuePlaceholder) {
23102d29afdSFrederik Gossen SmallVector<int64_t> staticValues;
23202d29afdSFrederik Gossen SmallVector<Value> dynamicValues;
23302d29afdSFrederik Gossen for (const auto &it : mixedValues) {
23402d29afdSFrederik Gossen if (it.is<Attribute>()) {
23502d29afdSFrederik Gossen staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
23602d29afdSFrederik Gossen } else {
23702d29afdSFrederik Gossen staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
23802d29afdSFrederik Gossen dynamicValues.push_back(it.get<Value>());
23902d29afdSFrederik Gossen }
24002d29afdSFrederik Gossen }
24102d29afdSFrederik Gossen return {b.getI64ArrayAttr(staticValues), dynamicValues};
24202d29afdSFrederik Gossen }
24302d29afdSFrederik Gossen
decomposeMixedStridesOrOffsets(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues)24402d29afdSFrederik Gossen std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
24502d29afdSFrederik Gossen OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
24602d29afdSFrederik Gossen return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
24702d29afdSFrederik Gossen }
24802d29afdSFrederik Gossen
24902d29afdSFrederik Gossen std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedSizes(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues)25002d29afdSFrederik Gossen mlir::decomposeMixedSizes(OpBuilder &b,
25102d29afdSFrederik Gossen const SmallVectorImpl<OpFoldResult> &mixedValues) {
25202d29afdSFrederik Gossen return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
25302d29afdSFrederik Gossen }
254