1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/ViewLikeInterface.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // ViewLike Interfaces 15 //===----------------------------------------------------------------------===// 16 17 /// Include the definitions of the loop-like interfaces. 18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" 19 20 LogicalResult mlir::verifyListOfOperandsOrIntegers( 21 Operation *op, StringRef name, unsigned numElements, ArrayAttr attr, 22 ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) { 23 /// Check static and dynamic offsets/sizes/strides does not overflow type. 24 if (attr.size() != numElements) 25 return op->emitError("expected ") 26 << numElements << " " << name << " values"; 27 unsigned expectedNumDynamicEntries = 28 llvm::count_if(attr.getValue(), [&](Attribute attr) { 29 return isDynamic(attr.cast<IntegerAttr>().getInt()); 30 }); 31 if (values.size() != expectedNumDynamicEntries) 32 return op->emitError("expected ") 33 << expectedNumDynamicEntries << " dynamic " << name << " values"; 34 return success(); 35 } 36 37 LogicalResult 38 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { 39 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); 40 // Offsets can come in 2 flavors: 41 // 1. Either single entry (when maxRanks == 1). 42 // 2. Or as an array whose rank must match that of the mixed sizes. 43 // So that the result type is well-formed. 44 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && 45 op.getMixedOffsets().size() != op.getMixedSizes().size()) 46 return op->emitError( 47 "expected mixed offsets rank to match mixed sizes rank (") 48 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() 49 << ") so the rank of the result type is well-formed."; 50 // Ranks of mixed sizes and strides must always match so the result type is 51 // well-formed. 52 if (op.getMixedSizes().size() != op.getMixedStrides().size()) 53 return op->emitError( 54 "expected mixed sizes rank to match mixed strides rank (") 55 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() 56 << ") so the rank of the result type is well-formed."; 57 58 if (failed(verifyListOfOperandsOrIntegers( 59 op, "offset", maxRanks[0], op.static_offsets(), op.offsets(), 60 ShapedType::isDynamicStrideOrOffset))) 61 return failure(); 62 if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1], 63 op.static_sizes(), op.sizes(), 64 ShapedType::isDynamic))) 65 return failure(); 66 if (failed(verifyListOfOperandsOrIntegers( 67 op, "stride", maxRanks[2], op.static_strides(), op.strides(), 68 ShapedType::isDynamicStrideOrOffset))) 69 return failure(); 70 return success(); 71 } 72 73 template <int64_t dynVal> 74 static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values, 75 ArrayAttr arrayAttr) { 76 p << '['; 77 if (arrayAttr.empty()) { 78 p << "]"; 79 return; 80 } 81 unsigned idx = 0; 82 llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { 83 int64_t val = a.cast<IntegerAttr>().getInt(); 84 if (val == dynVal) 85 p << values[idx++]; 86 else 87 p << val; 88 }); 89 p << ']'; 90 } 91 92 void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p, 93 Operation *op, 94 OperandRange values, 95 ArrayAttr integers) { 96 return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>( 97 p, values, integers); 98 } 99 100 void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op, 101 OperandRange values, 102 ArrayAttr integers) { 103 return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values, 104 integers); 105 } 106 107 template <int64_t dynVal> 108 static ParseResult 109 parseOperandsOrIntegersImpl(OpAsmParser &parser, 110 SmallVectorImpl<OpAsmParser::OperandType> &values, 111 ArrayAttr &integers) { 112 if (failed(parser.parseLSquare())) 113 return failure(); 114 // 0-D. 115 if (succeeded(parser.parseOptionalRSquare())) { 116 integers = parser.getBuilder().getArrayAttr({}); 117 return success(); 118 } 119 120 SmallVector<int64_t, 4> attrVals; 121 while (true) { 122 OpAsmParser::OperandType operand; 123 auto res = parser.parseOptionalOperand(operand); 124 if (res.hasValue() && succeeded(res.getValue())) { 125 values.push_back(operand); 126 attrVals.push_back(dynVal); 127 } else { 128 IntegerAttr attr; 129 if (failed(parser.parseAttribute<IntegerAttr>(attr))) 130 return parser.emitError(parser.getNameLoc()) 131 << "expected SSA value or integer"; 132 attrVals.push_back(attr.getInt()); 133 } 134 135 if (succeeded(parser.parseOptionalComma())) 136 continue; 137 if (failed(parser.parseRSquare())) 138 return failure(); 139 break; 140 } 141 integers = parser.getBuilder().getI64ArrayAttr(attrVals); 142 return success(); 143 } 144 145 ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList( 146 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values, 147 ArrayAttr &integers) { 148 return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>( 149 parser, values, integers); 150 } 151 152 ParseResult mlir::parseOperandsOrIntegersSizesList( 153 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values, 154 ArrayAttr &integers) { 155 return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values, 156 integers); 157 } 158 159 bool mlir::detail::sameOffsetsSizesAndStrides( 160 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, 161 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { 162 if (a.static_offsets().size() != b.static_offsets().size()) 163 return false; 164 if (a.static_sizes().size() != b.static_sizes().size()) 165 return false; 166 if (a.static_strides().size() != b.static_strides().size()) 167 return false; 168 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) 169 if (!cmp(std::get<0>(it), std::get<1>(it))) 170 return false; 171 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) 172 if (!cmp(std::get<0>(it), std::get<1>(it))) 173 return false; 174 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) 175 if (!cmp(std::get<0>(it), std::get<1>(it))) 176 return false; 177 return true; 178 } 179 180 void OffsetSizeAndStrideOpInterface::expandToRank( 181 Value target, SmallVector<OpFoldResult> &offsets, 182 SmallVector<OpFoldResult> &sizes, SmallVector<OpFoldResult> &strides, 183 llvm::function_ref<OpFoldResult(Value, int64_t)> createOrFoldDim) { 184 auto shapedType = target.getType().cast<ShapedType>(); 185 unsigned rank = shapedType.getRank(); 186 assert(offsets.size() == sizes.size() && "mismatched lengths"); 187 assert(offsets.size() == strides.size() && "mismatched lengths"); 188 assert(offsets.size() <= rank && "rank overflow"); 189 MLIRContext *ctx = target.getContext(); 190 Attribute zero = IntegerAttr::get(IndexType::get(ctx), APInt(64, 0)); 191 Attribute one = IntegerAttr::get(IndexType::get(ctx), APInt(64, 1)); 192 for (unsigned i = offsets.size(); i < rank; ++i) { 193 offsets.push_back(zero); 194 sizes.push_back(createOrFoldDim(target, i)); 195 strides.push_back(one); 196 } 197 } 198 199 SmallVector<OpFoldResult, 4> 200 mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op, 201 ArrayAttr staticOffsets, ValueRange offsets) { 202 SmallVector<OpFoldResult, 4> res; 203 unsigned numDynamic = 0; 204 unsigned count = static_cast<unsigned>(staticOffsets.size()); 205 for (unsigned idx = 0; idx < count; ++idx) { 206 if (op.isDynamicOffset(idx)) 207 res.push_back(offsets[numDynamic++]); 208 else 209 res.push_back(staticOffsets[idx]); 210 } 211 return res; 212 } 213 214 SmallVector<OpFoldResult, 4> 215 mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, 216 ValueRange sizes) { 217 SmallVector<OpFoldResult, 4> res; 218 unsigned numDynamic = 0; 219 unsigned count = static_cast<unsigned>(staticSizes.size()); 220 for (unsigned idx = 0; idx < count; ++idx) { 221 if (op.isDynamicSize(idx)) 222 res.push_back(sizes[numDynamic++]); 223 else 224 res.push_back(staticSizes[idx]); 225 } 226 return res; 227 } 228 229 SmallVector<OpFoldResult, 4> 230 mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op, 231 ArrayAttr staticStrides, ValueRange strides) { 232 SmallVector<OpFoldResult, 4> res; 233 unsigned numDynamic = 0; 234 unsigned count = static_cast<unsigned>(staticStrides.size()); 235 for (unsigned idx = 0; idx < count; ++idx) { 236 if (op.isDynamicStride(idx)) 237 res.push_back(strides[numDynamic++]); 238 else 239 res.push_back(staticStrides[idx]); 240 } 241 return res; 242 } 243