1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 static LogicalResult verifyOpWithOffsetSizesAndStridesPart( 21 OffsetSizeAndStrideOpInterface op, StringRef name, 22 unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, 23 llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) { 24 /// Check static and dynamic offsets/sizes/strides breakdown. 25 if (attr.size() != expectedNumElements) 26 return op.emitError("expected ") 27 << expectedNumElements << " " << name << " values"; 28 unsigned expectedNumDynamicEntries = 29 llvm::count_if(attr.getValue(), [&](Attribute attr) { 30 return isDynamic(attr.cast<IntegerAttr>().getInt()); 31 }); 32 if (values.size() != expectedNumDynamicEntries) 33 return op.emitError("expected ") 34 << expectedNumDynamicEntries << " dynamic " << name << " values"; 35 return success(); 36 } 37 38 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) { 39 std::array<unsigned, 3> ranks = op.getArrayAttrRanks(); 40 if (failed(verifyOpWithOffsetSizesAndStridesPart( 41 op, "offset", ranks[0], 42 OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), 43 op.static_offsets(), ShapedType::isDynamicStrideOrOffset, 44 op.offsets()))) 45 return failure(); 46 if (failed(verifyOpWithOffsetSizesAndStridesPart( 47 op, "size", ranks[1], 48 OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), 49 op.static_sizes(), ShapedType::isDynamic, op.sizes()))) 50 return failure(); 51 if (failed(verifyOpWithOffsetSizesAndStridesPart( 52 op, "stride", ranks[2], 53 OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), 54 op.static_strides(), ShapedType::isDynamicStrideOrOffset, 55 op.strides()))) 56 return failure(); 57 return success(); 58 } 59 60 /// Print a list with either (1) the static integer value in `arrayAttr` if 61 /// `isDynamic` evaluates to false or (2) the next value otherwise. 62 /// This allows idiomatic printing of mixed value and integer attributes in a 63 /// list. E.g. `[%arg0, 7, 42, %arg42]`. 64 static void 65 printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, 66 ArrayAttr arrayAttr, 67 llvm::function_ref<bool(int64_t)> isDynamic) { 68 p << '['; 69 unsigned idx = 0; 70 llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { 71 int64_t val = a.cast<IntegerAttr>().getInt(); 72 if (isDynamic(val)) 73 p << values[idx++]; 74 else 75 p << val; 76 }); 77 p << ']'; 78 } 79 80 void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p, 81 OffsetSizeAndStrideOpInterface op, 82 StringRef offsetPrefix, 83 StringRef sizePrefix, 84 StringRef stridePrefix, 85 ArrayRef<StringRef> elidedAttrs) { 86 p << offsetPrefix; 87 printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), 88 ShapedType::isDynamicStrideOrOffset); 89 p << sizePrefix; 90 printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), 91 ShapedType::isDynamic); 92 p << stridePrefix; 93 printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), 94 ShapedType::isDynamicStrideOrOffset); 95 p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); 96 } 97 98 /// Parse a mixed list with either (1) static integer values or (2) SSA values. 99 /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` 100 /// encode the position of SSA values. Add the parsed SSA values to `ssa` 101 /// in-order. 102 // 103 /// E.g. after parsing "[%arg0, 7, 42, %arg42]": 104 /// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" 105 /// 2. `ssa` is filled with "[%arg0, %arg1]". 106 static ParseResult 107 parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, 108 StringRef attrName, int64_t dynVal, 109 SmallVectorImpl<OpAsmParser::OperandType> &ssa) { 110 if (failed(parser.parseLSquare())) 111 return failure(); 112 // 0-D. 113 if (succeeded(parser.parseOptionalRSquare())) { 114 result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); 115 return success(); 116 } 117 118 SmallVector<int64_t, 4> attrVals; 119 while (true) { 120 OpAsmParser::OperandType operand; 121 auto res = parser.parseOptionalOperand(operand); 122 if (res.hasValue() && succeeded(res.getValue())) { 123 ssa.push_back(operand); 124 attrVals.push_back(dynVal); 125 } else { 126 IntegerAttr attr; 127 if (failed(parser.parseAttribute<IntegerAttr>(attr))) 128 return parser.emitError(parser.getNameLoc()) 129 << "expected SSA value or integer"; 130 attrVals.push_back(attr.getInt()); 131 } 132 133 if (succeeded(parser.parseOptionalComma())) 134 continue; 135 if (failed(parser.parseRSquare())) 136 return failure(); 137 break; 138 } 139 140 auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); 141 result.addAttribute(attrName, arrayAttr); 142 return success(); 143 } 144 145 ParseResult mlir::parseOffsetsSizesAndStrides( 146 OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes, 147 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix, 148 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix, 149 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) { 150 return parseOffsetsSizesAndStrides( 151 parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix, 152 parseOptionalSizePrefix, parseOptionalStridePrefix); 153 } 154 155 ParseResult mlir::parseOffsetsSizesAndStrides( 156 OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes, 157 llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)> 158 preResolutionFn, 159 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix, 160 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix, 161 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) { 162 SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo; 163 auto indexType = parser.getBuilder().getIndexType(); 164 if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) || 165 parseListOfOperandsOrIntegers( 166 parser, result, 167 OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), 168 ShapedType::kDynamicStrideOrOffset, offsetsInfo) || 169 (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) || 170 parseListOfOperandsOrIntegers( 171 parser, result, 172 OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), 173 ShapedType::kDynamicSize, sizesInfo) || 174 (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) || 175 parseListOfOperandsOrIntegers( 176 parser, result, 177 OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), 178 ShapedType::kDynamicStrideOrOffset, stridesInfo)) 179 return failure(); 180 // Add segment sizes to result 181 SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(), 182 segmentSizes.end()); 183 segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()), 184 static_cast<int>(sizesInfo.size()), 185 static_cast<int>(stridesInfo.size())}); 186 result.addAttribute( 187 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(), 188 parser.getBuilder().getI32VectorAttr(segmentSizesFinal)); 189 return failure( 190 (preResolutionFn && preResolutionFn(parser, result)) || 191 parser.resolveOperands(offsetsInfo, indexType, result.operands) || 192 parser.resolveOperands(sizesInfo, indexType, result.operands) || 193 parser.resolveOperands(stridesInfo, indexType, result.operands)); 194 } 195