1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(
21     Operation *op, StringRef name, unsigned maxNumElements, ArrayAttr attr,
22     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
23   /// Check static and dynamic offsets/sizes/strides does not overflow type.
24   if (attr.size() > maxNumElements)
25     return op->emitError("expected <= ")
26            << maxNumElements << " " << name << " values";
27   unsigned expectedNumDynamicEntries =
28       llvm::count_if(attr.getValue(), [&](Attribute attr) {
29         return isDynamic(attr.cast<IntegerAttr>().getInt());
30       });
31   if (values.size() != expectedNumDynamicEntries)
32     return op->emitError("expected ")
33            << expectedNumDynamicEntries << " dynamic " << name << " values";
34   return success();
35 }
36 
37 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
38   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
39   // Offsets can come in 2 flavors:
40   //   1. Either single entry (when maxRanks == 1).
41   //   2. Or as an array whose rank must match that of the mixed sizes.
42   // So that the result type is well-formed.
43   if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) &&
44       op.getMixedOffsets().size() != op.getMixedSizes().size())
45     return op->emitError(
46                "expected mixed offsets rank to match mixed sizes rank (")
47            << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
48            << ") so the rank of the result type is well-formed.";
49   // Ranks of mixed sizes and strides must always match so the result type is
50   // well-formed.
51   if (op.getMixedSizes().size() != op.getMixedStrides().size())
52     return op->emitError(
53                "expected mixed sizes rank to match mixed strides rank (")
54            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
55            << ") so the rank of the result type is well-formed.";
56 
57   if (failed(verifyListOfOperandsOrIntegers(
58           op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
59           ShapedType::isDynamicStrideOrOffset)))
60     return failure();
61   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
62                                             op.static_sizes(), op.sizes(),
63                                             ShapedType::isDynamic)))
64     return failure();
65   if (failed(verifyListOfOperandsOrIntegers(
66           op, "stride", maxRanks[2], op.static_strides(), op.strides(),
67           ShapedType::isDynamicStrideOrOffset)))
68     return failure();
69   return success();
70 }
71 
72 void mlir::printListOfOperandsOrIntegers(
73     OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
74     llvm::function_ref<bool(int64_t)> isDynamic) {
75   p << '[';
76   unsigned idx = 0;
77   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
78     int64_t val = a.cast<IntegerAttr>().getInt();
79     if (isDynamic(val))
80       p << values[idx++];
81     else
82       p << val;
83   });
84   p << ']';
85 }
86 
87 void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
88                                        OffsetSizeAndStrideOpInterface op,
89                                        StringRef offsetPrefix,
90                                        StringRef sizePrefix,
91                                        StringRef stridePrefix,
92                                        ArrayRef<StringRef> elidedAttrs) {
93   p << offsetPrefix;
94   printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
95                                 ShapedType::isDynamicStrideOrOffset);
96   p << sizePrefix;
97   printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
98                                 ShapedType::isDynamic);
99   p << stridePrefix;
100   printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
101                                 ShapedType::isDynamicStrideOrOffset);
102   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
103 }
104 
105 ParseResult mlir::parseListOfOperandsOrIntegers(
106     OpAsmParser &parser, OperationState &result, StringRef attrName,
107     int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
108   if (failed(parser.parseLSquare()))
109     return failure();
110   // 0-D.
111   if (succeeded(parser.parseOptionalRSquare())) {
112     result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
113     return success();
114   }
115 
116   SmallVector<int64_t, 4> attrVals;
117   while (true) {
118     OpAsmParser::OperandType operand;
119     auto res = parser.parseOptionalOperand(operand);
120     if (res.hasValue() && succeeded(res.getValue())) {
121       ssa.push_back(operand);
122       attrVals.push_back(dynVal);
123     } else {
124       IntegerAttr attr;
125       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
126         return parser.emitError(parser.getNameLoc())
127                << "expected SSA value or integer";
128       attrVals.push_back(attr.getInt());
129     }
130 
131     if (succeeded(parser.parseOptionalComma()))
132       continue;
133     if (failed(parser.parseRSquare()))
134       return failure();
135     break;
136   }
137 
138   auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
139   result.addAttribute(attrName, arrayAttr);
140   return success();
141 }
142 
143 ParseResult mlir::parseOffsetsSizesAndStrides(
144     OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
145     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
146     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
147     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
148   return parseOffsetsSizesAndStrides(
149       parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
150       parseOptionalSizePrefix, parseOptionalStridePrefix);
151 }
152 
153 ParseResult mlir::parseOffsetsSizesAndStrides(
154     OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
155     llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
156         preResolutionFn,
157     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
158     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
159     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
160   SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
161   auto indexType = parser.getBuilder().getIndexType();
162   if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
163       parseListOfOperandsOrIntegers(
164           parser, result,
165           OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
166           ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
167       (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
168       parseListOfOperandsOrIntegers(
169           parser, result,
170           OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
171           ShapedType::kDynamicSize, sizesInfo) ||
172       (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
173       parseListOfOperandsOrIntegers(
174           parser, result,
175           OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
176           ShapedType::kDynamicStrideOrOffset, stridesInfo))
177     return failure();
178   // Add segment sizes to result
179   SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
180                                         segmentSizes.end());
181   segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
182                             static_cast<int>(sizesInfo.size()),
183                             static_cast<int>(stridesInfo.size())});
184   result.addAttribute(
185       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
186       parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
187   return failure(
188       (preResolutionFn && preResolutionFn(parser, result)) ||
189       parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
190       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
191       parser.resolveOperands(stridesInfo, indexType, result.operands));
192 }
193