12458cd27SLei Zhang //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
22458cd27SLei Zhang //
32458cd27SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42458cd27SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
52458cd27SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62458cd27SLei Zhang //
72458cd27SLei Zhang //===----------------------------------------------------------------------===//
82458cd27SLei Zhang 
92458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.h"
102458cd27SLei Zhang 
112458cd27SLei Zhang using namespace mlir;
122458cd27SLei Zhang 
132458cd27SLei Zhang //===----------------------------------------------------------------------===//
142458cd27SLei Zhang // ViewLike Interfaces
152458cd27SLei Zhang //===----------------------------------------------------------------------===//
162458cd27SLei Zhang 
/// Include the definitions of the view-like interfaces.
182458cd27SLei Zhang #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19a8de412fSNicolas Vasilache 
verifyListOfOperandsOrIntegers(Operation * op,StringRef name,unsigned numElements,ArrayAttr attr,ValueRange values,llvm::function_ref<bool (int64_t)> isDynamic)20118a7156SMaheshRavishankar LogicalResult mlir::verifyListOfOperandsOrIntegers(
217df7586aSMaheshRavishankar     Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
22118a7156SMaheshRavishankar     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
235133673dSNicolas Vasilache   /// Check static and dynamic offsets/sizes/strides does not overflow type.
247df7586aSMaheshRavishankar   if (attr.size() != numElements)
257df7586aSMaheshRavishankar     return op->emitError("expected ")
267df7586aSMaheshRavishankar            << numElements << " " << name << " values";
27a8de412fSNicolas Vasilache   unsigned expectedNumDynamicEntries =
28a8de412fSNicolas Vasilache       llvm::count_if(attr.getValue(), [&](Attribute attr) {
29a8de412fSNicolas Vasilache         return isDynamic(attr.cast<IntegerAttr>().getInt());
30a8de412fSNicolas Vasilache       });
31a8de412fSNicolas Vasilache   if (values.size() != expectedNumDynamicEntries)
32118a7156SMaheshRavishankar     return op->emitError("expected ")
33a8de412fSNicolas Vasilache            << expectedNumDynamicEntries << " dynamic " << name << " values";
34a8de412fSNicolas Vasilache   return success();
35a8de412fSNicolas Vasilache }
36a8de412fSNicolas Vasilache 
3762851ea7SUday Bondhugula LogicalResult
verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)3862851ea7SUday Bondhugula mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
395133673dSNicolas Vasilache   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
405133673dSNicolas Vasilache   // Offsets can come in 2 flavors:
415133673dSNicolas Vasilache   //   1. Either single entry (when maxRanks == 1).
425133673dSNicolas Vasilache   //   2. Or as an array whose rank must match that of the mixed sizes.
435133673dSNicolas Vasilache   // So that the result type is well-formed.
44*84124ff8SMehdi Amini   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
455133673dSNicolas Vasilache       op.getMixedOffsets().size() != op.getMixedSizes().size())
465133673dSNicolas Vasilache     return op->emitError(
475133673dSNicolas Vasilache                "expected mixed offsets rank to match mixed sizes rank (")
485133673dSNicolas Vasilache            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
495133673dSNicolas Vasilache            << ") so the rank of the result type is well-formed.";
505133673dSNicolas Vasilache   // Ranks of mixed sizes and strides must always match so the result type is
515133673dSNicolas Vasilache   // well-formed.
525133673dSNicolas Vasilache   if (op.getMixedSizes().size() != op.getMixedStrides().size())
535133673dSNicolas Vasilache     return op->emitError(
545133673dSNicolas Vasilache                "expected mixed sizes rank to match mixed strides rank (")
555133673dSNicolas Vasilache            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
565133673dSNicolas Vasilache            << ") so the rank of the result type is well-formed.";
575133673dSNicolas Vasilache 
58118a7156SMaheshRavishankar   if (failed(verifyListOfOperandsOrIntegers(
595133673dSNicolas Vasilache           op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
60118a7156SMaheshRavishankar           ShapedType::isDynamicStrideOrOffset)))
61a8de412fSNicolas Vasilache     return failure();
625133673dSNicolas Vasilache   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
63118a7156SMaheshRavishankar                                             op.static_sizes(), op.sizes(),
64118a7156SMaheshRavishankar                                             ShapedType::isDynamic)))
65a8de412fSNicolas Vasilache     return failure();
66118a7156SMaheshRavishankar   if (failed(verifyListOfOperandsOrIntegers(
675133673dSNicolas Vasilache           op, "stride", maxRanks[2], op.static_strides(), op.strides(),
68118a7156SMaheshRavishankar           ShapedType::isDynamicStrideOrOffset)))
69a8de412fSNicolas Vasilache     return failure();
70a8de412fSNicolas Vasilache   return success();
71a8de412fSNicolas Vasilache }
72b6c71c13SNicolas Vasilache 
73342d4662SMaheshRavishankar template <int64_t dynVal>
printOperandsOrIntegersListImpl(OpAsmPrinter & p,ValueRange values,ArrayAttr arrayAttr)74342d4662SMaheshRavishankar static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
75342d4662SMaheshRavishankar                                             ArrayAttr arrayAttr) {
76c2470810SNicolas Vasilache   p << '[';
77342d4662SMaheshRavishankar   if (arrayAttr.empty()) {
78342d4662SMaheshRavishankar     p << "]";
79342d4662SMaheshRavishankar     return;
80342d4662SMaheshRavishankar   }
81c2470810SNicolas Vasilache   unsigned idx = 0;
82c2470810SNicolas Vasilache   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
83c2470810SNicolas Vasilache     int64_t val = a.cast<IntegerAttr>().getInt();
84342d4662SMaheshRavishankar     if (val == dynVal)
85c2470810SNicolas Vasilache       p << values[idx++];
86c2470810SNicolas Vasilache     else
87c2470810SNicolas Vasilache       p << val;
88c2470810SNicolas Vasilache   });
89c2470810SNicolas Vasilache   p << ']';
90c2470810SNicolas Vasilache }
91c2470810SNicolas Vasilache 
printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter & p,Operation * op,OperandRange values,ArrayAttr integers)92342d4662SMaheshRavishankar void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
93342d4662SMaheshRavishankar                                                        Operation *op,
94342d4662SMaheshRavishankar                                                        OperandRange values,
95342d4662SMaheshRavishankar                                                        ArrayAttr integers) {
96342d4662SMaheshRavishankar   return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
97342d4662SMaheshRavishankar       p, values, integers);
98c2470810SNicolas Vasilache }
99c2470810SNicolas Vasilache 
printOperandsOrIntegersSizesList(OpAsmPrinter & p,Operation * op,OperandRange values,ArrayAttr integers)100342d4662SMaheshRavishankar void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
101342d4662SMaheshRavishankar                                             OperandRange values,
102342d4662SMaheshRavishankar                                             ArrayAttr integers) {
103342d4662SMaheshRavishankar   return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
104342d4662SMaheshRavishankar                                                                    integers);
105342d4662SMaheshRavishankar }
106342d4662SMaheshRavishankar 
107342d4662SMaheshRavishankar template <int64_t dynVal>
parseOperandsOrIntegersImpl(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)108e13d23bcSMarkus Böck static ParseResult parseOperandsOrIntegersImpl(
109e13d23bcSMarkus Böck     OpAsmParser &parser,
110e13d23bcSMarkus Böck     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
111342d4662SMaheshRavishankar     ArrayAttr &integers) {
112b6c71c13SNicolas Vasilache   if (failed(parser.parseLSquare()))
113b6c71c13SNicolas Vasilache     return failure();
114b6c71c13SNicolas Vasilache   // 0-D.
115b6c71c13SNicolas Vasilache   if (succeeded(parser.parseOptionalRSquare())) {
116342d4662SMaheshRavishankar     integers = parser.getBuilder().getArrayAttr({});
117b6c71c13SNicolas Vasilache     return success();
118b6c71c13SNicolas Vasilache   }
119b6c71c13SNicolas Vasilache 
120b6c71c13SNicolas Vasilache   SmallVector<int64_t, 4> attrVals;
121b6c71c13SNicolas Vasilache   while (true) {
122e13d23bcSMarkus Böck     OpAsmParser::UnresolvedOperand operand;
123b6c71c13SNicolas Vasilache     auto res = parser.parseOptionalOperand(operand);
1243b7c3a65SKazu Hirata     if (res.hasValue() && succeeded(res.getValue())) {
125342d4662SMaheshRavishankar       values.push_back(operand);
126b6c71c13SNicolas Vasilache       attrVals.push_back(dynVal);
127b6c71c13SNicolas Vasilache     } else {
128b6c71c13SNicolas Vasilache       IntegerAttr attr;
129b6c71c13SNicolas Vasilache       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
130b6c71c13SNicolas Vasilache         return parser.emitError(parser.getNameLoc())
131b6c71c13SNicolas Vasilache                << "expected SSA value or integer";
132b6c71c13SNicolas Vasilache       attrVals.push_back(attr.getInt());
133b6c71c13SNicolas Vasilache     }
134b6c71c13SNicolas Vasilache 
135b6c71c13SNicolas Vasilache     if (succeeded(parser.parseOptionalComma()))
136b6c71c13SNicolas Vasilache       continue;
137b6c71c13SNicolas Vasilache     if (failed(parser.parseRSquare()))
138b6c71c13SNicolas Vasilache       return failure();
139b6c71c13SNicolas Vasilache     break;
140b6c71c13SNicolas Vasilache   }
141342d4662SMaheshRavishankar   integers = parser.getBuilder().getI64ArrayAttr(attrVals);
142b6c71c13SNicolas Vasilache   return success();
143b6c71c13SNicolas Vasilache }
144b6c71c13SNicolas Vasilache 
parseOperandsOrIntegersOffsetsOrStridesList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)145342d4662SMaheshRavishankar ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
146e13d23bcSMarkus Böck     OpAsmParser &parser,
147e13d23bcSMarkus Böck     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
148342d4662SMaheshRavishankar     ArrayAttr &integers) {
149342d4662SMaheshRavishankar   return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
150342d4662SMaheshRavishankar       parser, values, integers);
151c2470810SNicolas Vasilache }
152c2470810SNicolas Vasilache 
parseOperandsOrIntegersSizesList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & values,ArrayAttr & integers)153342d4662SMaheshRavishankar ParseResult mlir::parseOperandsOrIntegersSizesList(
154e13d23bcSMarkus Böck     OpAsmParser &parser,
155e13d23bcSMarkus Böck     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
156342d4662SMaheshRavishankar     ArrayAttr &integers) {
157342d4662SMaheshRavishankar   return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
158342d4662SMaheshRavishankar                                                                integers);
159b6c71c13SNicolas Vasilache }
160ce4f99e7SNicolas Vasilache 
sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a,OffsetSizeAndStrideOpInterface b,llvm::function_ref<bool (OpFoldResult,OpFoldResult)> cmp)161ce4f99e7SNicolas Vasilache bool mlir::detail::sameOffsetsSizesAndStrides(
162ce4f99e7SNicolas Vasilache     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
163ce4f99e7SNicolas Vasilache     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
164ce4f99e7SNicolas Vasilache   if (a.static_offsets().size() != b.static_offsets().size())
165ce4f99e7SNicolas Vasilache     return false;
166ce4f99e7SNicolas Vasilache   if (a.static_sizes().size() != b.static_sizes().size())
167ce4f99e7SNicolas Vasilache     return false;
168ce4f99e7SNicolas Vasilache   if (a.static_strides().size() != b.static_strides().size())
169ce4f99e7SNicolas Vasilache     return false;
170ce4f99e7SNicolas Vasilache   for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
171ce4f99e7SNicolas Vasilache     if (!cmp(std::get<0>(it), std::get<1>(it)))
172ce4f99e7SNicolas Vasilache       return false;
173ce4f99e7SNicolas Vasilache   for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
174ce4f99e7SNicolas Vasilache     if (!cmp(std::get<0>(it), std::get<1>(it)))
175ce4f99e7SNicolas Vasilache       return false;
176ce4f99e7SNicolas Vasilache   for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
177ce4f99e7SNicolas Vasilache     if (!cmp(std::get<0>(it), std::get<1>(it)))
178ce4f99e7SNicolas Vasilache       return false;
179ce4f99e7SNicolas Vasilache   return true;
180ce4f99e7SNicolas Vasilache }
181d0ee094dSNicolas Vasilache 
182b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4>
getMixedOffsets(OffsetSizeAndStrideOpInterface op,ArrayAttr staticOffsets,ValueRange offsets)183b65f21a4SIvan Butygin mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
184b65f21a4SIvan Butygin                       ArrayAttr staticOffsets, ValueRange offsets) {
185b65f21a4SIvan Butygin   SmallVector<OpFoldResult, 4> res;
186b65f21a4SIvan Butygin   unsigned numDynamic = 0;
187b65f21a4SIvan Butygin   unsigned count = static_cast<unsigned>(staticOffsets.size());
188b65f21a4SIvan Butygin   for (unsigned idx = 0; idx < count; ++idx) {
189b65f21a4SIvan Butygin     if (op.isDynamicOffset(idx))
190b65f21a4SIvan Butygin       res.push_back(offsets[numDynamic++]);
191b65f21a4SIvan Butygin     else
192b65f21a4SIvan Butygin       res.push_back(staticOffsets[idx]);
193b65f21a4SIvan Butygin   }
194b65f21a4SIvan Butygin   return res;
195b65f21a4SIvan Butygin }
196b65f21a4SIvan Butygin 
197b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4>
getMixedSizes(OffsetSizeAndStrideOpInterface op,ArrayAttr staticSizes,ValueRange sizes)198b65f21a4SIvan Butygin mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
199b65f21a4SIvan Butygin                     ValueRange sizes) {
200b65f21a4SIvan Butygin   SmallVector<OpFoldResult, 4> res;
201b65f21a4SIvan Butygin   unsigned numDynamic = 0;
202b65f21a4SIvan Butygin   unsigned count = static_cast<unsigned>(staticSizes.size());
203b65f21a4SIvan Butygin   for (unsigned idx = 0; idx < count; ++idx) {
204b65f21a4SIvan Butygin     if (op.isDynamicSize(idx))
205b65f21a4SIvan Butygin       res.push_back(sizes[numDynamic++]);
206b65f21a4SIvan Butygin     else
207b65f21a4SIvan Butygin       res.push_back(staticSizes[idx]);
208b65f21a4SIvan Butygin   }
209b65f21a4SIvan Butygin   return res;
210b65f21a4SIvan Butygin }
211b65f21a4SIvan Butygin 
212b65f21a4SIvan Butygin SmallVector<OpFoldResult, 4>
getMixedStrides(OffsetSizeAndStrideOpInterface op,ArrayAttr staticStrides,ValueRange strides)213b65f21a4SIvan Butygin mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
214b65f21a4SIvan Butygin                       ArrayAttr staticStrides, ValueRange strides) {
215b65f21a4SIvan Butygin   SmallVector<OpFoldResult, 4> res;
216b65f21a4SIvan Butygin   unsigned numDynamic = 0;
217b65f21a4SIvan Butygin   unsigned count = static_cast<unsigned>(staticStrides.size());
218b65f21a4SIvan Butygin   for (unsigned idx = 0; idx < count; ++idx) {
219b65f21a4SIvan Butygin     if (op.isDynamicStride(idx))
220b65f21a4SIvan Butygin       res.push_back(strides[numDynamic++]);
221b65f21a4SIvan Butygin     else
222b65f21a4SIvan Butygin       res.push_back(staticStrides[idx]);
223b65f21a4SIvan Butygin   }
224b65f21a4SIvan Butygin   return res;
225b65f21a4SIvan Butygin }
22602d29afdSFrederik Gossen 
22702d29afdSFrederik Gossen static std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedImpl(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues,const int64_t dynamicValuePlaceholder)22802d29afdSFrederik Gossen decomposeMixedImpl(OpBuilder &b,
22902d29afdSFrederik Gossen                    const SmallVectorImpl<OpFoldResult> &mixedValues,
23002d29afdSFrederik Gossen                    const int64_t dynamicValuePlaceholder) {
23102d29afdSFrederik Gossen   SmallVector<int64_t> staticValues;
23202d29afdSFrederik Gossen   SmallVector<Value> dynamicValues;
23302d29afdSFrederik Gossen   for (const auto &it : mixedValues) {
23402d29afdSFrederik Gossen     if (it.is<Attribute>()) {
23502d29afdSFrederik Gossen       staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
23602d29afdSFrederik Gossen     } else {
23702d29afdSFrederik Gossen       staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
23802d29afdSFrederik Gossen       dynamicValues.push_back(it.get<Value>());
23902d29afdSFrederik Gossen     }
24002d29afdSFrederik Gossen   }
24102d29afdSFrederik Gossen   return {b.getI64ArrayAttr(staticValues), dynamicValues};
24202d29afdSFrederik Gossen }
24302d29afdSFrederik Gossen 
decomposeMixedStridesOrOffsets(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues)24402d29afdSFrederik Gossen std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
24502d29afdSFrederik Gossen     OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
24602d29afdSFrederik Gossen   return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
24702d29afdSFrederik Gossen }
24802d29afdSFrederik Gossen 
24902d29afdSFrederik Gossen std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedSizes(OpBuilder & b,const SmallVectorImpl<OpFoldResult> & mixedValues)25002d29afdSFrederik Gossen mlir::decomposeMixedSizes(OpBuilder &b,
25102d29afdSFrederik Gossen                           const SmallVectorImpl<OpFoldResult> &mixedValues) {
25202d29afdSFrederik Gossen   return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
25302d29afdSFrederik Gossen }
254