1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 static LogicalResult verifyOpWithOffsetSizesAndStridesPart( 21 OffsetSizeAndStrideOpInterface op, StringRef name, 22 unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, 23 llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) { 24 /// Check static and dynamic offsets/sizes/strides breakdown. 25 if (attr.size() != expectedNumElements) 26 return op.emitError("expected ") 27 << expectedNumElements << " " << name << " values"; 28 unsigned expectedNumDynamicEntries = 29 llvm::count_if(attr.getValue(), [&](Attribute attr) { 30 return isDynamic(attr.cast<IntegerAttr>().getInt()); 31 }); 32 if (values.size() != expectedNumDynamicEntries) 33 return op.emitError("expected ") 34 << expectedNumDynamicEntries << " dynamic " << name << " values"; 35 return success(); 36 } 37 38 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) { 39 std::array<unsigned, 3> ranks = op.getArrayAttrRanks(); 40 if (failed(verifyOpWithOffsetSizesAndStridesPart( 41 op, "offset", ranks[0], 42 OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), 43 op.static_offsets(), ShapedType::isDynamicStrideOrOffset, 44 op.offsets()))) 45 return failure(); 46 if (failed(verifyOpWithOffsetSizesAndStridesPart( 47 op, "size", ranks[1], 48 OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), 49 op.static_sizes(), ShapedType::isDynamic, op.sizes()))) 50 return failure(); 51 if (failed(verifyOpWithOffsetSizesAndStridesPart( 52 op, "stride", ranks[2], 53 OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), 54 op.static_strides(), ShapedType::isDynamicStrideOrOffset, 55 op.strides()))) 56 return failure(); 57 return success(); 58 } 59 60 /// Parse a mixed list with either (1) static integer values or (2) SSA values. 61 /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` 62 /// encode the position of SSA values. Add the parsed SSA values to `ssa` 63 /// in-order. 64 // 65 /// E.g. after parsing "[%arg0, 7, 42, %arg42]": 66 /// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" 67 /// 2. `ssa` is filled with "[%arg0, %arg1]". 68 static ParseResult 69 parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, 70 StringRef attrName, int64_t dynVal, 71 SmallVectorImpl<OpAsmParser::OperandType> &ssa) { 72 if (failed(parser.parseLSquare())) 73 return failure(); 74 // 0-D. 75 if (succeeded(parser.parseOptionalRSquare())) { 76 result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); 77 return success(); 78 } 79 80 SmallVector<int64_t, 4> attrVals; 81 while (true) { 82 OpAsmParser::OperandType operand; 83 auto res = parser.parseOptionalOperand(operand); 84 if (res.hasValue() && succeeded(res.getValue())) { 85 ssa.push_back(operand); 86 attrVals.push_back(dynVal); 87 } else { 88 IntegerAttr attr; 89 if (failed(parser.parseAttribute<IntegerAttr>(attr))) 90 return parser.emitError(parser.getNameLoc()) 91 << "expected SSA value or integer"; 92 attrVals.push_back(attr.getInt()); 93 } 94 95 if (succeeded(parser.parseOptionalComma())) 96 continue; 97 if (failed(parser.parseRSquare())) 98 return failure(); 99 break; 100 } 101 102 auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); 103 result.addAttribute(attrName, arrayAttr); 104 return success(); 105 } 106 107 ParseResult mlir::parseOffsetsSizesAndStrides( 108 OpAsmParser &parser, 109 OperationState &result, 110 ArrayRef<int> segmentSizes, 111 llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)> 112 preResolutionFn, 113 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix, 114 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix, 115 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) { 116 SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo; 117 auto indexType = parser.getBuilder().getIndexType(); 118 if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) || 119 parseListOfOperandsOrIntegers( 120 parser, result, 121 OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), 122 ShapedType::kDynamicStrideOrOffset, offsetsInfo) || 123 (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) || 124 parseListOfOperandsOrIntegers( 125 parser, result, 126 OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), 127 ShapedType::kDynamicSize, sizesInfo) || 128 (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) || 129 parseListOfOperandsOrIntegers( 130 parser, result, 131 OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), 132 ShapedType::kDynamicStrideOrOffset, stridesInfo)) 133 return failure(); 134 // Add segment sizes to result 135 SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(), segmentSizes.end()); 136 segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()), 137 static_cast<int>(sizesInfo.size()), 138 static_cast<int>(stridesInfo.size())}); 139 auto b = parser.getBuilder(); 140 result.addAttribute( 141 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(), 142 b.getI32VectorAttr(segmentSizesFinal)); 143 return failure( 144 (preResolutionFn && preResolutionFn(parser, result)) || 145 parser.resolveOperands(offsetsInfo, indexType, result.operands) || 146 parser.resolveOperands(sizesInfo, indexType, result.operands) || 147 parser.resolveOperands(stridesInfo, indexType, result.operands)); 148 } 149