1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 LogicalResult mlir::verifyListOfOperandsOrIntegers( 21 Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr, 22 ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) { 23 /// Check static and dynamic offsets/sizes/strides breakdown. 24 if (attr.size() != expectedNumElements) 25 return op->emitError("expected ") 26 << expectedNumElements << " " << name << " values"; 27 unsigned expectedNumDynamicEntries = 28 llvm::count_if(attr.getValue(), [&](Attribute attr) { 29 return isDynamic(attr.cast<IntegerAttr>().getInt()); 30 }); 31 if (values.size() != expectedNumDynamicEntries) 32 return op->emitError("expected ") 33 << expectedNumDynamicEntries << " dynamic " << name << " values"; 34 return success(); 35 } 36 37 LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) { 38 std::array<unsigned, 3> ranks = op.getArrayAttrRanks(); 39 if (failed(verifyListOfOperandsOrIntegers( 40 op, "offset", ranks[0], op.static_offsets(), op.offsets(), 41 ShapedType::isDynamicStrideOrOffset))) 42 return failure(); 43 if (failed(verifyListOfOperandsOrIntegers(op, "size", ranks[1], 44 op.static_sizes(), op.sizes(), 45 ShapedType::isDynamic))) 46 return failure(); 47 if (failed(verifyListOfOperandsOrIntegers( 48 op, "stride", ranks[2], op.static_strides(), op.strides(), 49 ShapedType::isDynamicStrideOrOffset))) 50 return failure(); 51 return success(); 52 } 53 54 void mlir::printListOfOperandsOrIntegers( 55 OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, 56 llvm::function_ref<bool(int64_t)> isDynamic) { 57 p << '['; 58 unsigned idx = 0; 59 llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { 60 int64_t val = a.cast<IntegerAttr>().getInt(); 61 if (isDynamic(val)) 62 p << values[idx++]; 63 else 64 p << val; 65 }); 66 p << ']'; 67 } 68 69 void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p, 70 OffsetSizeAndStrideOpInterface op, 71 StringRef offsetPrefix, 72 StringRef sizePrefix, 73 StringRef stridePrefix, 74 ArrayRef<StringRef> elidedAttrs) { 75 p << offsetPrefix; 76 printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), 77 ShapedType::isDynamicStrideOrOffset); 78 p << sizePrefix; 79 printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), 80 ShapedType::isDynamic); 81 p << stridePrefix; 82 printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), 83 ShapedType::isDynamicStrideOrOffset); 84 p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); 85 } 86 87 ParseResult mlir::parseListOfOperandsOrIntegers( 88 OpAsmParser &parser, OperationState &result, StringRef attrName, 89 int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) { 90 if (failed(parser.parseLSquare())) 91 return failure(); 92 // 0-D. 93 if (succeeded(parser.parseOptionalRSquare())) { 94 result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); 95 return success(); 96 } 97 98 SmallVector<int64_t, 4> attrVals; 99 while (true) { 100 OpAsmParser::OperandType operand; 101 auto res = parser.parseOptionalOperand(operand); 102 if (res.hasValue() && succeeded(res.getValue())) { 103 ssa.push_back(operand); 104 attrVals.push_back(dynVal); 105 } else { 106 IntegerAttr attr; 107 if (failed(parser.parseAttribute<IntegerAttr>(attr))) 108 return parser.emitError(parser.getNameLoc()) 109 << "expected SSA value or integer"; 110 attrVals.push_back(attr.getInt()); 111 } 112 113 if (succeeded(parser.parseOptionalComma())) 114 continue; 115 if (failed(parser.parseRSquare())) 116 return failure(); 117 break; 118 } 119 120 auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); 121 result.addAttribute(attrName, arrayAttr); 122 return success(); 123 } 124 125 ParseResult mlir::parseOffsetsSizesAndStrides( 126 OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes, 127 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix, 128 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix, 129 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) { 130 return parseOffsetsSizesAndStrides( 131 parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix, 132 parseOptionalSizePrefix, parseOptionalStridePrefix); 133 } 134 135 ParseResult mlir::parseOffsetsSizesAndStrides( 136 OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes, 137 llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)> 138 preResolutionFn, 139 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix, 140 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix, 141 llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) { 142 SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo; 143 auto indexType = parser.getBuilder().getIndexType(); 144 if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) || 145 parseListOfOperandsOrIntegers( 146 parser, result, 147 OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), 148 ShapedType::kDynamicStrideOrOffset, offsetsInfo) || 149 (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) || 150 parseListOfOperandsOrIntegers( 151 parser, result, 152 OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), 153 ShapedType::kDynamicSize, sizesInfo) || 154 (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) || 155 parseListOfOperandsOrIntegers( 156 parser, result, 157 OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), 158 ShapedType::kDynamicStrideOrOffset, stridesInfo)) 159 return failure(); 160 // Add segment sizes to result 161 SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(), 162 segmentSizes.end()); 163 segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()), 164 static_cast<int>(sizesInfo.size()), 165 static_cast<int>(stridesInfo.size())}); 166 result.addAttribute( 167 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(), 168 parser.getBuilder().getI32VectorAttr(segmentSizesFinal)); 169 return failure( 170 (preResolutionFn && preResolutionFn(parser, result)) || 171 parser.resolveOperands(offsetsInfo, indexType, result.operands) || 172 parser.resolveOperands(sizesInfo, indexType, result.operands) || 173 parser.resolveOperands(stridesInfo, indexType, result.operands)); 174 } 175