1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ViewLikeInterface.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
/// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
20 LogicalResult mlir::verifyListOfOperandsOrIntegers(
21     Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
22     ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
23   /// Check static and dynamic offsets/sizes/strides breakdown.
24   if (attr.size() != expectedNumElements)
25     return op->emitError("expected ")
26            << expectedNumElements << " " << name << " values";
27   unsigned expectedNumDynamicEntries =
28       llvm::count_if(attr.getValue(), [&](Attribute attr) {
29         return isDynamic(attr.cast<IntegerAttr>().getInt());
30       });
31   if (values.size() != expectedNumDynamicEntries)
32     return op->emitError("expected ")
33            << expectedNumDynamicEntries << " dynamic " << name << " values";
34   return success();
35 }
36 
37 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
38   std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
39   if (failed(verifyListOfOperandsOrIntegers(
40           op, "offset", ranks[0], op.static_offsets(), op.offsets(),
41           ShapedType::isDynamicStrideOrOffset)))
42     return failure();
43   if (failed(verifyListOfOperandsOrIntegers(op, "size", ranks[1],
44                                             op.static_sizes(), op.sizes(),
45                                             ShapedType::isDynamic)))
46     return failure();
47   if (failed(verifyListOfOperandsOrIntegers(
48           op, "stride", ranks[2], op.static_strides(), op.strides(),
49           ShapedType::isDynamicStrideOrOffset)))
50     return failure();
51   return success();
52 }
53 
54 void mlir::printListOfOperandsOrIntegers(
55     OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
56     llvm::function_ref<bool(int64_t)> isDynamic) {
57   p << '[';
58   unsigned idx = 0;
59   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
60     int64_t val = a.cast<IntegerAttr>().getInt();
61     if (isDynamic(val))
62       p << values[idx++];
63     else
64       p << val;
65   });
66   p << ']';
67 }
68 
69 void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
70                                        OffsetSizeAndStrideOpInterface op,
71                                        StringRef offsetPrefix,
72                                        StringRef sizePrefix,
73                                        StringRef stridePrefix,
74                                        ArrayRef<StringRef> elidedAttrs) {
75   p << offsetPrefix;
76   printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
77                                 ShapedType::isDynamicStrideOrOffset);
78   p << sizePrefix;
79   printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
80                                 ShapedType::isDynamic);
81   p << stridePrefix;
82   printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
83                                 ShapedType::isDynamicStrideOrOffset);
84   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
85 }
86 
87 ParseResult mlir::parseListOfOperandsOrIntegers(
88     OpAsmParser &parser, OperationState &result, StringRef attrName,
89     int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
90   if (failed(parser.parseLSquare()))
91     return failure();
92   // 0-D.
93   if (succeeded(parser.parseOptionalRSquare())) {
94     result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
95     return success();
96   }
97 
98   SmallVector<int64_t, 4> attrVals;
99   while (true) {
100     OpAsmParser::OperandType operand;
101     auto res = parser.parseOptionalOperand(operand);
102     if (res.hasValue() && succeeded(res.getValue())) {
103       ssa.push_back(operand);
104       attrVals.push_back(dynVal);
105     } else {
106       IntegerAttr attr;
107       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
108         return parser.emitError(parser.getNameLoc())
109                << "expected SSA value or integer";
110       attrVals.push_back(attr.getInt());
111     }
112 
113     if (succeeded(parser.parseOptionalComma()))
114       continue;
115     if (failed(parser.parseRSquare()))
116       return failure();
117     break;
118   }
119 
120   auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
121   result.addAttribute(attrName, arrayAttr);
122   return success();
123 }
124 
125 ParseResult mlir::parseOffsetsSizesAndStrides(
126     OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
127     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
128     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
129     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
130   return parseOffsetsSizesAndStrides(
131       parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
132       parseOptionalSizePrefix, parseOptionalStridePrefix);
133 }
134 
135 ParseResult mlir::parseOffsetsSizesAndStrides(
136     OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
137     llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
138         preResolutionFn,
139     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
140     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
141     llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
142   SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
143   auto indexType = parser.getBuilder().getIndexType();
144   if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
145       parseListOfOperandsOrIntegers(
146           parser, result,
147           OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
148           ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
149       (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
150       parseListOfOperandsOrIntegers(
151           parser, result,
152           OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
153           ShapedType::kDynamicSize, sizesInfo) ||
154       (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
155       parseListOfOperandsOrIntegers(
156           parser, result,
157           OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
158           ShapedType::kDynamicStrideOrOffset, stridesInfo))
159     return failure();
160   // Add segment sizes to result
161   SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
162                                         segmentSizes.end());
163   segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
164                             static_cast<int>(sizesInfo.size()),
165                             static_cast<int>(stridesInfo.size())});
166   result.addAttribute(
167       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
168       parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
169   return failure(
170       (preResolutionFn && preResolutionFn(parser, result)) ||
171       parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
172       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
173       parser.resolveOperands(stridesInfo, indexType, result.operands));
174 }
175