1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
21     OffsetSizeAndStrideOpInterface op, StringRef name,
22     unsigned expectedNumElements, StringRef attrName, ArrayAttr attr,
23     llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
24   /// Check static and dynamic offsets/sizes/strides breakdown.
25   if (attr.size() != expectedNumElements)
26     return op.emitError("expected ")
27            << expectedNumElements << " " << name << " values";
28   unsigned expectedNumDynamicEntries =
29       llvm::count_if(attr.getValue(), [&](Attribute attr) {
30         return isDynamic(attr.cast<IntegerAttr>().getInt());
31       });
32   if (values.size() != expectedNumDynamicEntries)
33     return op.emitError("expected ")
34            << expectedNumDynamicEntries << " dynamic " << name << " values";
35   return success();
36 }
37 
38 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
39   std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
40   if (failed(verifyOpWithOffsetSizesAndStridesPart(
41           op, "offset", ranks[0],
42           OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
43           op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
44           op.offsets())))
45     return failure();
46   if (failed(verifyOpWithOffsetSizesAndStridesPart(
47           op, "size", ranks[1],
48           OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
49           op.static_sizes(), ShapedType::isDynamic, op.sizes())))
50     return failure();
51   if (failed(verifyOpWithOffsetSizesAndStridesPart(
52           op, "stride", ranks[2],
53           OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
54           op.static_strides(), ShapedType::isDynamicStrideOrOffset,
55           op.strides())))
56     return failure();
57   return success();
58 }
59 
60 /// Parse a mixed list with either (1) static integer values or (2) SSA values.
61 /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
62 /// encode the position of SSA values. Add the parsed SSA values to `ssa`
63 /// in-order.
64 //
65 /// E.g. after parsing "[%arg0, 7, 42, %arg42]":
66 ///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
67 ///   2. `ssa` is filled with "[%arg0, %arg1]".
68 static ParseResult
69 parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
70                               StringRef attrName, int64_t dynVal,
71                               SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
72   if (failed(parser.parseLSquare()))
73     return failure();
74   // 0-D.
75   if (succeeded(parser.parseOptionalRSquare())) {
76     result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
77     return success();
78   }
79 
80   SmallVector<int64_t, 4> attrVals;
81   while (true) {
82     OpAsmParser::OperandType operand;
83     auto res = parser.parseOptionalOperand(operand);
84     if (res.hasValue() && succeeded(res.getValue())) {
85       ssa.push_back(operand);
86       attrVals.push_back(dynVal);
87     } else {
88       IntegerAttr attr;
89       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
90         return parser.emitError(parser.getNameLoc())
91                << "expected SSA value or integer";
92       attrVals.push_back(attr.getInt());
93     }
94 
95     if (succeeded(parser.parseOptionalComma()))
96       continue;
97     if (failed(parser.parseRSquare()))
98       return failure();
99     break;
100   }
101 
102   auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
103   result.addAttribute(attrName, arrayAttr);
104   return success();
105 }
106 
107 ParseResult mlir::parseOffsetsSizesAndStrides(
108     OpAsmParser &parser,
109     OperationState &result,
110     ArrayRef<int> segmentSizes,
111     llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
112         preResolutionFn,
113     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
114     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
115     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
116   SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
117   auto indexType = parser.getBuilder().getIndexType();
118   if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
119       parseListOfOperandsOrIntegers(
120           parser, result,
121           OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
122           ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
123       (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
124       parseListOfOperandsOrIntegers(
125           parser, result,
126           OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
127           ShapedType::kDynamicSize, sizesInfo) ||
128       (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
129       parseListOfOperandsOrIntegers(
130           parser, result,
131           OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
132           ShapedType::kDynamicStrideOrOffset, stridesInfo))
133     return failure();
134   // Add segment sizes to result
135   SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(), segmentSizes.end());
136   segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
137                                       static_cast<int>(sizesInfo.size()),
138                                       static_cast<int>(stridesInfo.size())});
139   auto b = parser.getBuilder();
140   result.addAttribute(
141       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
142       b.getI32VectorAttr(segmentSizesFinal));
143   return failure(
144       (preResolutionFn && preResolutionFn(parser, result)) ||
145       parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
146       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
147       parser.resolveOperands(stridesInfo, indexType, result.operands));
148 }
149