1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 LogicalResult mlir::verifyListOfOperandsOrIntegers( 21 Operation *op, StringRef name, unsigned maxNumElements, ArrayAttr attr, 22 ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) { 23 /// Check static and dynamic offsets/sizes/strides does not overflow type. 24 if (attr.size() > maxNumElements) 25 return op->emitError("expected <= ") 26 << maxNumElements << " " << name << " values"; 27 unsigned expectedNumDynamicEntries = 28 llvm::count_if(attr.getValue(), [&](Attribute attr) { 29 return isDynamic(attr.cast<IntegerAttr>().getInt()); 30 }); 31 if (values.size() != expectedNumDynamicEntries) 32 return op->emitError("expected ") 33 << expectedNumDynamicEntries << " dynamic " << name << " values"; 34 return success(); 35 } 36 37 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) { 38 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); 39 // Offsets can come in 2 flavors: 40 // 1. Either single entry (when maxRanks == 1). 41 // 2. Or as an array whose rank must match that of the mixed sizes. 42 // So that the result type is well-formed. 43 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && 44 op.getMixedOffsets().size() != op.getMixedSizes().size()) 45 return op->emitError( 46 "expected mixed offsets rank to match mixed sizes rank (") 47 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() 48 << ") so the rank of the result type is well-formed."; 49 // Ranks of mixed sizes and strides must always match so the result type is 50 // well-formed. 51 if (op.getMixedSizes().size() != op.getMixedStrides().size()) 52 return op->emitError( 53 "expected mixed sizes rank to match mixed strides rank (") 54 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() 55 << ") so the rank of the result type is well-formed."; 56 57 if (failed(verifyListOfOperandsOrIntegers( 58 op, "offset", maxRanks[0], op.static_offsets(), op.offsets(), 59 ShapedType::isDynamicStrideOrOffset))) 60 return failure(); 61 if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1], 62 op.static_sizes(), op.sizes(), 63 ShapedType::isDynamic))) 64 return failure(); 65 if (failed(verifyListOfOperandsOrIntegers( 66 op, "stride", maxRanks[2], op.static_strides(), op.strides(), 67 ShapedType::isDynamicStrideOrOffset))) 68 return failure(); 69 return success(); 70 } 71 72 void mlir::printListOfOperandsOrIntegers( 73 OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, 74 llvm::function_ref<bool(int64_t)> isDynamic) { 75 p << '['; 76 unsigned idx = 0; 77 llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { 78 int64_t val = a.cast<IntegerAttr>().getInt(); 79 if (isDynamic(val)) 80 p << values[idx++]; 81 else 82 p << val; 83 }); 84 p << ']'; 85 } 86 87 void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p, 88 OffsetSizeAndStrideOpInterface op, 89 StringRef offsetPrefix, 90 StringRef sizePrefix, 91 StringRef stridePrefix, 92 ArrayRef<StringRef> elidedAttrs) { 93 p << offsetPrefix; 94 printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), 95 ShapedType::isDynamicStrideOrOffset); 96 p << sizePrefix; 97 printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), 98 ShapedType::isDynamic); 99 p << stridePrefix; 100 printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), 101 ShapedType::isDynamicStrideOrOffset); 102 p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); 103 } 104 105 ParseResult mlir::parseListOfOperandsOrIntegers( 106 OpAsmParser &parser, OperationState &result, StringRef attrName, 107 int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) { 108 if (failed(parser.parseLSquare())) 109 return failure(); 110 // 0-D. 111 if (succeeded(parser.parseOptionalRSquare())) { 112 result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); 113 return success(); 114 } 115 116 SmallVector<int64_t, 4> attrVals; 117 while (true) { 118 OpAsmParser::OperandType operand; 119 auto res = parser.parseOptionalOperand(operand); 120 if (res.hasValue() && succeeded(res.getValue())) { 121 ssa.push_back(operand); 122 attrVals.push_back(dynVal); 123 } else { 124 IntegerAttr attr; 125 if (failed(parser.parseAttribute<IntegerAttr>(attr))) 126 return parser.emitError(parser.getNameLoc()) 127 << "expected SSA value or integer"; 128 attrVals.push_back(attr.getInt()); 129 } 130 131 if (succeeded(parser.parseOptionalComma())) 132 continue; 133 if (failed(parser.parseRSquare())) 134 return failure(); 135 break; 136 } 137 138 auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); 139 result.addAttribute(attrName, arrayAttr); 140 return success(); 141 } 142 143 ParseResult mlir::parseOffsetsSizesAndStrides( 144 OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes, 145 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix, 146 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix, 147 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) { 148 return parseOffsetsSizesAndStrides( 149 parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix, 150 parseOptionalSizePrefix, parseOptionalStridePrefix); 151 } 152 153 ParseResult mlir::parseOffsetsSizesAndStrides( 154 OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes, 155 llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)> 156 preResolutionFn, 157 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix, 158 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix, 159 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) { 160 SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo; 161 auto indexType = parser.getBuilder().getIndexType(); 162 if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) || 163 parseListOfOperandsOrIntegers( 164 parser, result, 165 OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), 166 ShapedType::kDynamicStrideOrOffset, offsetsInfo) || 167 (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) || 168 parseListOfOperandsOrIntegers( 169 parser, result, 170 OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), 171 ShapedType::kDynamicSize, sizesInfo) || 172 (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) || 173 parseListOfOperandsOrIntegers( 174 parser, result, 175 OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), 176 ShapedType::kDynamicStrideOrOffset, stridesInfo)) 177 return failure(); 178 // Add segment sizes to result 179 SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(), 180 segmentSizes.end()); 181 segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()), 182 static_cast<int>(sizesInfo.size()), 183 static_cast<int>(stridesInfo.size())}); 184 result.addAttribute( 185 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(), 186 parser.getBuilder().getI32VectorAttr(segmentSizesFinal)); 187 return failure( 188 (preResolutionFn && preResolutionFn(parser, result)) || 189 parser.resolveOperands(offsetsInfo, indexType, result.operands) || 190 parser.resolveOperands(sizesInfo, indexType, result.operands) || 191 parser.resolveOperands(stridesInfo, indexType, result.operands)); 192 } 193