//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat get1DMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.mask_dim_sizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}

// Helper for verifying combining kinds in contractions and reductions.
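// Illustrative summary (not exhaustive): MINF/MAXF only apply to float
// element types, the integer min/max and bitwise kinds only to int/index
// elements, while ADD and MUL accept int, index, or float elements.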
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return elementType.isa<FloatType>();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      shapedType.getElementType().dyn_cast<VectorType>();
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
         defWrite.indices() == read.indices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.permutation_map() == read.permutation_map();
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.indices() == write.indices() &&
         priorWrite.mask() == write.mask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.permutation_map() == write.permutation_map();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
    auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
    auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
    // If any of the indices are dynamic we cannot prove anything.
    if (!indexA || !indexB)
      continue;

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different, we know we are accessing disjoint slices.
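      // E.g. (illustrative) two transfers of vector<4xf32> whose constant
      // leading indices are 0 and 1 touch different slices of the source.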
      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
          indexB.getValue().cast<IntegerAttr>().getInt())
        return true;
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t distance =
          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
                   indexB.getValue().cast<IntegerAttr>().getInt());
      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
        return true;
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB) {
  if (transferA.source() != transferB.source())
    return false;
  return isDisjointTransferIndices(transferA, transferB);
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MINF,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,
    CombiningKind::MAXF,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};

void CombiningKindAttr::print(AsmPrinter &printer) const {
  printer << "<";
  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
    return bitEnumContains(this->getKind(), kind);
  });
  llvm::interleaveComma(kinds, printer,
                        [&](auto kind) { printer << stringifyEnum(kind); });
  printer << ">";
}

Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  StringRef elemName;
  if (failed(parser.parseKeyword(&elemName)))
    return {};

  auto kind = symbolizeCombiningKind(elemName);
  if (!kind) {
    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
        << elemName;
    return {};
  }

  if (failed(parser.parseGreater()))
    return {};

  return CombiningKindAttr::get(kind.getValue(), parser.getContext());
}

Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  if (attrKind == "kind")
    return CombiningKindAttr::parse(parser, {});

  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
  return {};
}

void
VectorDialect::printAttribute(Attribute attr,
                              DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
    os << "kind";
    ck.print(os);
    return;
  }
  llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
}

LogicalResult MultiDimReductionOp::inferReturnTypes(
    MLIRContext *, Optional<Location>, ValueRange operands,
    DictionaryAttr attributes, RegionRange,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  MultiDimReductionOp::Adaptor op(operands, attributes);
  auto vectorType = op.source().getType().cast<VectorType>();
  SmallVector<int64_t> targetShape;
  for (auto it : llvm::enumerate(vectorType.getShape()))
    if (!llvm::any_of(op.reduction_dims().getValue(), [&](Attribute attr) {
          return attr.cast<IntegerAttr>().getValue() == it.index();
        }))
      targetShape.push_back(it.value());
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnTypes.push_back(vectorType.getElementType());
  else
    inferredReturnTypes.push_back(
        VectorType::get(targetShape, vectorType.getElementType()));
  return success();
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
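  // E.g. (illustrative) a multi_reduction over a 1-D vector<4xf32> with no
  // reduction dimensions simply forwards its source operand.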
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return source();
  return {};
}

Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector) {
  build(builder, result, kind, vector, /*acc=*/Value());
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc) {
  build(builder, result, vector.getType().cast<VectorType>().getElementType(),
        kind, vector, acc);
}

LogicalResult ReductionOp::verify() {
  // Verify for 1-D vector.
  int64_t rank = getVectorType().getRank();
  if (rank != 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = dest().getType();
  if (!isSupportedCombiningKind(kind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(kind())
           << "'";

  // Verify optional accumulator.
  if (acc()) {
    if (kind() != CombiningKind::ADD && kind() != CombiningKind::MUL)
      return emitOpError("no accumulator for reduction kind: ")
             << stringifyCombiningKind(kind());
    if (!eltType.isa<FloatType>())
      return emitOpError("no accumulator for type: ") << eltType;
  }

  return success();
}

ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  CombiningKindAttr kindAttr;
  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
                                              result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (!operandsInfo.empty() &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.empty() || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

void ReductionOp::print(OpAsmPrinter &p) {
  p << " ";
  kindAttr().print(p);
  p << ", " << vector();
  if (acc())
    p << ", " << acc();
  p << " : " << vector().getType() << " into " << dest().getType();
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maxf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(::mlir::getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(::mlir::getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrStrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
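  // The parsed textual form looks roughly like (illustrative, abbreviated):
  //   %res = vector.contract {indexing_maps = [...], iterator_types = [...]}
  //            %lhs, %rhs, %acc
  //            : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>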
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
    result.addAttribute(ContractionOp::getKindAttrStrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs())
    if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << lhs() << ", ";
  p << rhs() << ", " << acc();
  if (masks().size() == 2)
    p << ", " << masks();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << lhs().getType() << ", " << rhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to
  // 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

LogicalResult ContractionOp::verify() {
  auto lhsType = getLhsType();
  auto rhsType = getRhsType();
  auto accType = getAccType();
  auto resType = getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (indexing_maps().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
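  // E.g. for a matmul-style contraction over iterators (i, j, k), the three
  // maps would be (i, j, k) -> (i, k), (i, j, k) -> (k, j) and
  // (i, j, k) -> (i, j), each taking all three iterators as inputs.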
  unsigned numIterators = iterator_types().getValue().size();
  for (const auto &it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and
    // indices. This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = getLHSVectorMaskType();
  auto rhsMaskType = getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType =
      vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(kind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
                                         ::mlir::getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrStrName()};
  return llvm::makeArrayRef(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
/// %c0 = vector.constant 0: ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.acc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.acc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.acc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source) {
  result.addOperands({source});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source,
                                     Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getVectorType();
  if (vectorType.getRank() == 0) {
    if (position())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!position())
    return emitOpError("expected position for 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, getVectorSubscriptAttr(builder, position));
}

// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  build(builder, result, source, positionConstants);
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
                            ValueRange operands, DictionaryAttr attributes,
                            RegionRange,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  ExtractOp::Adaptor op(operands, attributes);
  auto vectorType = op.vector().getType().cast<VectorType>();
  if (static_cast<int64_t>(op.position().size()) == vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(op.position().size(), vectorType.getRank() - 1);
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType()));
  }
  return success();
}

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
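  // E.g. f32 and vector<1xf32> are treated as interchangeable result types.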
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = l.front().dyn_cast<VectorType>();
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  auto positionAttr = position().getValue();
  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= getVectorType().getDimSize(en.index()))
      return emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extrPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extrPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result
/// vector. As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
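  /// E.g. position [1, 2] is contained within [1, 2, 3] but not within
  /// [1, 3, 2].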
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as
  /// equality of all entries of `a` that are >= 0 with the corresponding
  /// entries of `b`. Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto it : llvm::zip(a, b)) {
      if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
        continue;
      if (std::get<0>(it) != std::get<1>(it))
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the
  /// result. They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
      extractedRank(extractOp.position().size()) {
  assert(vectorRank >= extractedRank && "extracted pos overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition = extractVector<int64_t>(extractOp.position());
  llvm::append_range(extractPosition, sentinels);
}

// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (!nextTransposeOp)
    return failure();
  auto permutation = extractVector<unsigned>(nextTransposeOp.transp());
  AffineMap m = inversePermutation(
      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
  extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
  return success();
}

// Case 2: the insert position matches extractPosition exactly, early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (makeArrayRef(insertedPos) !=
      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.source();
  // Case 2.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Case 3: if inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
  // Drop extra leading dims.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
  extractedRank = extractPosition.size() - sentinels.size();
  // Case 3.a. early-exit fold (break and delegate to post-while path).
  res = nextInsertOp.source();
  // Case 3.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // If we can't fold (either internal transposition, or nothing to fold), bail.
  bool nothingToFold = (source == extractOp.vector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold by updating the op in place and return its result.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(
      extractOp.positionAttrName(),
      b.getI64ArrayAttr(
          makeArrayRef(extractPosition).take_front(extractedRank)));
  extractOp.vectorMutable().assign(source);
  return extractOp.getResult();
}

/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {
  Value valueToExtractFrom = extractOp.vector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.vector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.dest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.vector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();
  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(source.getType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(source);
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                       b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  return Value();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
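  // E.g. for vector<2x3x4xf32>, n = 0 yields 4 and n = 2 yields 2.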
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}

/// Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  auto extractStridedSliceOp =
      extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();
  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank >
      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
    return Value();
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(extractedPos));
  return extractOp.getResult();
}

/// Fold extract_op fed from a chain of insertStridedSlice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
  int64_t destinationRank = op.getType().isa<VectorType>()
                                ? op.getType().cast<VectorType>().getRank()
                                : 0;
  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
    auto extractOffsets = extractVector<int64_t>(op.position());

    if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted element chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      op.vectorMutable().assign(insertOp.source());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractOp::getPositionAttrStrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return op.getResult();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
1440 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
1441 }
1442 return Value();
1443 }
1444
1445 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1446 if (position().empty())
1447 return vector();
1448 if (succeeded(foldExtractOpFromExtractChain(*this)))
1449 return getResult();
1450 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1451 return res;
1452 if (auto res = foldExtractFromBroadcast(*this))
1453 return res;
1454 if (auto res = foldExtractFromShapeCast(*this))
1455 return res;
1456 if (auto val = foldExtractFromExtractStrided(*this))
1457 return val;
1458 if (auto val = foldExtractStridedOpFromInsertChain(*this))
1459 return val;
1460 return OpFoldResult();
1461 }
1462
1463 namespace {
1464
1465 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
1466 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1467 public:
1468 using OpRewritePattern<ExtractOp>::OpRewritePattern;
1469
1470 LogicalResult matchAndRewrite(ExtractOp extractOp,
1471 PatternRewriter &rewriter) const override {
1472 Operation *defOp = extractOp.vector().getDefiningOp();
1473 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1474 return failure();
1475 Value source = defOp->getOperand(0);
1476 if (extractOp.getType() == source.getType())
1477 return failure();
1478 auto getRank = [](Type type) {
1479 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1480 };
1481 unsigned broadcastSrcRank = getRank(source.getType());
1482 unsigned extractResultRank = getRank(extractOp.getType());
1483 // We only consider the case where the rank of the source is smaller than
1484 // the rank of the extract dst. The other cases are handled in the folding
1485 // patterns.
1486 if (extractResultRank <= broadcastSrcRank)
1487 return failure();
1488 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1489 extractOp, extractOp.getType(), source);
1490 return success();
1491 }
1492 };
1493
1494 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
1495 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
1496 public:
1497 using OpRewritePattern<ExtractOp>::OpRewritePattern;
1498
1499 LogicalResult matchAndRewrite(ExtractOp extractOp,
1500 PatternRewriter &rewriter) const override {
1501 // Return if the 'extractOp' operand is not defined by a
1502 // ConstantOp.
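// A minimal sketch of the rewrite (illustrative values, not from a test):
//   %c = arith.constant dense<1.0> : vector<4x8xf32>
//   %e = vector.extract %c[2] : vector<4x8xf32>
// becomes
//   %e = arith.constant dense<1.0> : vector<8xf32>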
1503 auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>(); 1504 if (!constantOp) 1505 return failure(); 1506 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 1507 if (!dense) 1508 return failure(); 1509 Attribute newAttr = dense.getSplatValue<Attribute>(); 1510 if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>()) 1511 newAttr = DenseElementsAttr::get(vecDstType, newAttr); 1512 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr); 1513 return success(); 1514 } 1515 }; 1516 1517 } // namespace 1518 1519 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, 1520 MLIRContext *context) { 1521 results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context); 1522 } 1523 1524 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 1525 SmallVectorImpl<int64_t> &results) { 1526 for (auto attr : arrayAttr) 1527 results.push_back(attr.cast<IntegerAttr>().getInt()); 1528 } 1529 1530 //===----------------------------------------------------------------------===// 1531 // ExtractMapOp 1532 //===----------------------------------------------------------------------===// 1533 1534 void ExtractMapOp::build(OpBuilder &builder, OperationState &result, 1535 Value vector, ValueRange ids, 1536 ArrayRef<int64_t> multiplicity, 1537 AffineMap permutationMap) { 1538 assert(ids.size() == multiplicity.size() && 1539 ids.size() == permutationMap.getNumResults()); 1540 assert(permutationMap.isProjectedPermutation()); 1541 VectorType type = vector.getType().cast<VectorType>(); 1542 SmallVector<int64_t, 4> newShape(type.getShape().begin(), 1543 type.getShape().end()); 1544 for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) { 1545 AffineExpr expr = permutationMap.getResult(i); 1546 auto dim = expr.cast<AffineDimExpr>(); 1547 newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i]; 1548 } 1549 VectorType resultType = VectorType::get(newShape, type.getElementType()); 1550 ExtractMapOp::build(builder, result, resultType, vector, ids); 1551 } 1552 1553 LogicalResult ExtractMapOp::verify() { 1554 if (getSourceVectorType().getRank() != getResultType().getRank()) 1555 return emitOpError("expected source and destination vectors of same rank"); 1556 unsigned numId = 0; 1557 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) { 1558 if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) != 1559 0) 1560 return emitOpError("source vector dimensions must be a multiple of " 1561 "destination vector dimensions"); 1562 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) 1563 numId++; 1564 } 1565 if (numId != ids().size()) 1566 return emitOpError("expected number of ids must match the number of " 1567 "dimensions distributed"); 1568 return success(); 1569 } 1570 1571 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) { 1572 auto insert = vector().getDefiningOp<vector::InsertMapOp>(); 1573 if (insert == nullptr || getType() != insert.vector().getType() || 1574 ids() != insert.ids()) 1575 return {}; 1576 return insert.vector(); 1577 } 1578 1579 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) { 1580 assert(multiplicity.empty()); 1581 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) { 1582 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) 1583 multiplicity.push_back(getSourceVectorType().getDimSize(i) / 1584 getResultType().getDimSize(i)); 1585 } 1586 } 1587 1588 template 
<typename MapOp> 1589 AffineMap calculateImplicitMap(MapOp op) { 1590 SmallVector<AffineExpr, 4> perm; 1591 // Check which dimension have a multiplicity greater than 1 and associated 1592 // them to the IDs in order. 1593 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) { 1594 if (op.getSourceVectorType().getDimSize(i) != 1595 op.getResultType().getDimSize(i)) 1596 perm.push_back(getAffineDimExpr(i, op.getContext())); 1597 } 1598 auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm, 1599 op.getContext()); 1600 return map; 1601 } 1602 1603 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); } 1604 1605 //===----------------------------------------------------------------------===// 1606 // FmaOp 1607 //===----------------------------------------------------------------------===// 1608 1609 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() { 1610 return llvm::to_vector<4>(getVectorType().getShape()); 1611 } 1612 1613 //===----------------------------------------------------------------------===// 1614 // BroadcastOp 1615 //===----------------------------------------------------------------------===// 1616 1617 BroadcastableToResult 1618 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, 1619 std::pair<int, int> *mismatchingDims) { 1620 // Broadcast scalar to vector of the same element type. 1621 if (srcType.isIntOrIndexOrFloat() && dstVectorType && 1622 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) 1623 return BroadcastableToResult::Success; 1624 // From now on, only vectors broadcast. 1625 VectorType srcVectorType = srcType.dyn_cast<VectorType>(); 1626 if (!srcVectorType) 1627 return BroadcastableToResult::SourceTypeNotAVector; 1628 1629 int64_t srcRank = srcVectorType.getRank(); 1630 int64_t dstRank = dstVectorType.getRank(); 1631 if (srcRank > dstRank) 1632 return BroadcastableToResult::SourceRankHigher; 1633 // Source has an exact match or singleton value for all trailing dimensions 1634 // (all leading dimensions are simply duplicated). 1635 int64_t lead = dstRank - srcRank; 1636 for (int64_t r = 0; r < srcRank; ++r) { 1637 int64_t srcDim = srcVectorType.getDimSize(r); 1638 int64_t dstDim = dstVectorType.getDimSize(lead + r); 1639 if (srcDim != 1 && srcDim != dstDim) { 1640 if (mismatchingDims) { 1641 mismatchingDims->first = srcDim; 1642 mismatchingDims->second = dstDim; 1643 } 1644 return BroadcastableToResult::DimensionMismatch; 1645 } 1646 } 1647 1648 return BroadcastableToResult::Success; 1649 } 1650 1651 LogicalResult BroadcastOp::verify() { 1652 std::pair<int, int> mismatchingDims; 1653 BroadcastableToResult res = 1654 isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims); 1655 if (res == BroadcastableToResult::Success) 1656 return success(); 1657 if (res == BroadcastableToResult::SourceRankHigher) 1658 return emitOpError("source rank higher than destination rank"); 1659 if (res == BroadcastableToResult::DimensionMismatch) 1660 return emitOpError("dimension mismatch (") 1661 << mismatchingDims.first << " vs. 
" << mismatchingDims.second << ")"; 1662 if (res == BroadcastableToResult::SourceTypeNotAVector) 1663 return emitOpError("source type is not a vector"); 1664 llvm_unreachable("unexpected vector.broadcast op error"); 1665 } 1666 1667 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 1668 if (getSourceType() == getVectorType()) 1669 return source(); 1670 if (!operands[0]) 1671 return {}; 1672 auto vectorType = getVectorType(); 1673 if (operands[0].getType().isIntOrIndexOrFloat()) 1674 return DenseElementsAttr::get(vectorType, operands[0]); 1675 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>()) 1676 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()); 1677 return {}; 1678 } 1679 1680 namespace { 1681 1682 // Fold broadcast1(broadcast2(x)) into broadcast1(x). 1683 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { 1684 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 1685 1686 LogicalResult matchAndRewrite(BroadcastOp broadcastOp, 1687 PatternRewriter &rewriter) const override { 1688 auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>(); 1689 if (!srcBroadcast) 1690 return failure(); 1691 rewriter.replaceOpWithNewOp<BroadcastOp>( 1692 broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source()); 1693 return success(); 1694 } 1695 }; 1696 } // namespace 1697 1698 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, 1699 MLIRContext *context) { 1700 // BroadcastToShapeCast is not a default canonicalization, it is opt-in by 1701 // calling `populateCastAwayVectorLeadingOneDimPatterns` 1702 results.add<BroadcastFolder>(context); 1703 } 1704 1705 //===----------------------------------------------------------------------===// 1706 // ShuffleOp 1707 //===----------------------------------------------------------------------===// 1708 1709 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, 1710 Value v2, ArrayRef<int64_t> mask) { 1711 build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask)); 1712 } 1713 1714 LogicalResult ShuffleOp::verify() { 1715 VectorType resultType = getVectorType(); 1716 VectorType v1Type = getV1VectorType(); 1717 VectorType v2Type = getV2VectorType(); 1718 // Verify ranks. 1719 int64_t resRank = resultType.getRank(); 1720 int64_t v1Rank = v1Type.getRank(); 1721 int64_t v2Rank = v2Type.getRank(); 1722 if (resRank != v1Rank || v1Rank != v2Rank) 1723 return emitOpError("rank mismatch"); 1724 // Verify all but leading dimension sizes. 1725 for (int64_t r = 1; r < v1Rank; ++r) { 1726 int64_t resDim = resultType.getDimSize(r); 1727 int64_t v1Dim = v1Type.getDimSize(r); 1728 int64_t v2Dim = v2Type.getDimSize(r); 1729 if (resDim != v1Dim || v1Dim != v2Dim) 1730 return emitOpError("dimension mismatch"); 1731 } 1732 // Verify mask length. 1733 auto maskAttr = mask().getValue(); 1734 int64_t maskLength = maskAttr.size(); 1735 if (maskLength <= 0) 1736 return emitOpError("invalid mask length"); 1737 if (maskLength != resultType.getDimSize(0)) 1738 return emitOpError("mask length mismatch"); 1739 // Verify all indices. 
1740 int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); 1741 for (const auto &en : llvm::enumerate(maskAttr)) { 1742 auto attr = en.value().dyn_cast<IntegerAttr>(); 1743 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) 1744 return emitOpError("mask index #") << (en.index() + 1) << " out of range"; 1745 } 1746 return success(); 1747 } 1748 1749 LogicalResult 1750 ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>, 1751 ValueRange operands, DictionaryAttr attributes, 1752 RegionRange, 1753 SmallVectorImpl<Type> &inferredReturnTypes) { 1754 ShuffleOp::Adaptor op(operands, attributes); 1755 auto v1Type = op.v1().getType().cast<VectorType>(); 1756 // Construct resulting type: leading dimension matches mask length, 1757 // all trailing dimensions match the operands. 1758 SmallVector<int64_t, 4> shape; 1759 shape.reserve(v1Type.getRank()); 1760 shape.push_back(std::max<size_t>(1, op.mask().size())); 1761 llvm::append_range(shape, v1Type.getShape().drop_front()); 1762 inferredReturnTypes.push_back( 1763 VectorType::get(shape, v1Type.getElementType())); 1764 return success(); 1765 } 1766 1767 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) { 1768 Attribute lhs = operands.front(), rhs = operands.back(); 1769 if (!lhs || !rhs) 1770 return {}; 1771 1772 auto lhsType = lhs.getType().cast<VectorType>(); 1773 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr 1774 // manipulation. 1775 if (lhsType.getRank() != 1) 1776 return {}; 1777 int64_t lhsSize = lhsType.getDimSize(0); 1778 1779 SmallVector<Attribute> results; 1780 auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>(); 1781 auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>(); 1782 for (const auto &index : this->mask().getAsValueRange<IntegerAttr>()) { 1783 int64_t i = index.getZExtValue(); 1784 if (i >= lhsSize) { 1785 results.push_back(rhsElements[i - lhsSize]); 1786 } else { 1787 results.push_back(lhsElements[i]); 1788 } 1789 } 1790 1791 return DenseElementsAttr::get(getVectorType(), results); 1792 } 1793 1794 //===----------------------------------------------------------------------===// 1795 // InsertElementOp 1796 //===----------------------------------------------------------------------===// 1797 1798 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 1799 Value source, Value dest) { 1800 build(builder, result, source, dest, {}); 1801 } 1802 1803 LogicalResult InsertElementOp::verify() { 1804 auto dstVectorType = getDestVectorType(); 1805 if (dstVectorType.getRank() == 0) { 1806 if (position()) 1807 return emitOpError("expected position to be empty with 0-D vector"); 1808 return success(); 1809 } 1810 if (dstVectorType.getRank() != 1) 1811 return emitOpError("unexpected >1 vector rank"); 1812 if (!position()) 1813 return emitOpError("expected position for 1-D vector"); 1814 return success(); 1815 } 1816 1817 //===----------------------------------------------------------------------===// 1818 // InsertOp 1819 //===----------------------------------------------------------------------===// 1820 1821 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1822 Value dest, ArrayRef<int64_t> position) { 1823 result.addOperands({source, dest}); 1824 auto positionAttr = getVectorSubscriptAttr(builder, position); 1825 result.addTypes(dest.getType()); 1826 result.addAttribute(getPositionAttrStrName(), positionAttr); 1827 } 1828 1829 // Convenience builder which assumes the values are constant 
indices. 1830 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1831 Value dest, ValueRange position) { 1832 SmallVector<int64_t, 4> positionConstants = 1833 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { 1834 return pos.getDefiningOp<arith::ConstantIndexOp>().value(); 1835 })); 1836 build(builder, result, source, dest, positionConstants); 1837 } 1838 1839 LogicalResult InsertOp::verify() { 1840 auto positionAttr = position().getValue(); 1841 auto destVectorType = getDestVectorType(); 1842 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank())) 1843 return emitOpError( 1844 "expected position attribute of rank smaller than dest vector rank"); 1845 auto srcVectorType = getSourceType().dyn_cast<VectorType>(); 1846 if (srcVectorType && 1847 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() != 1848 static_cast<unsigned>(destVectorType.getRank()))) 1849 return emitOpError("expected position attribute rank + source rank to " 1850 "match dest vector rank"); 1851 if (!srcVectorType && 1852 (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank()))) 1853 return emitOpError( 1854 "expected position attribute rank to match the dest vector rank"); 1855 for (const auto &en : llvm::enumerate(positionAttr)) { 1856 auto attr = en.value().dyn_cast<IntegerAttr>(); 1857 if (!attr || attr.getInt() < 0 || 1858 attr.getInt() >= destVectorType.getDimSize(en.index())) 1859 return emitOpError("expected position attribute #") 1860 << (en.index() + 1) 1861 << " to be a non-negative integer smaller than the corresponding " 1862 "dest vector dimension"; 1863 } 1864 return success(); 1865 } 1866 1867 namespace { 1868 1869 // If insertOp is only inserting unit dimensions it can be transformed to a 1870 // broadcast. 1871 class InsertToBroadcast final : public OpRewritePattern<InsertOp> { 1872 public: 1873 using OpRewritePattern<InsertOp>::OpRewritePattern; 1874 1875 LogicalResult matchAndRewrite(InsertOp insertOp, 1876 PatternRewriter &rewriter) const override { 1877 auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>(); 1878 if (!srcVecType || insertOp.getDestVectorType().getNumElements() != 1879 srcVecType.getNumElements()) 1880 return failure(); 1881 rewriter.replaceOpWithNewOp<BroadcastOp>( 1882 insertOp, insertOp.getDestVectorType(), insertOp.source()); 1883 return success(); 1884 } 1885 }; 1886 1887 } // namespace 1888 1889 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, 1890 MLIRContext *context) { 1891 results.add<InsertToBroadcast, BroadcastFolder>(context); 1892 } 1893 1894 // Eliminates insert operations that produce values identical to their source 1895 // value. This happens when the source and destination vectors have identical 1896 // sizes. 
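// For instance (an illustrative case): inserting a vector<4xf32> source at
// an empty position into a vector<4xf32> destination yields the source
// value unchanged, so the fold below simply returns the source operand.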
1897 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
1898 if (position().empty())
1899 return source();
1900 return {};
1901 }
1902
1903 //===----------------------------------------------------------------------===//
1904 // InsertMapOp
1905 //===----------------------------------------------------------------------===//
1906
1907 LogicalResult InsertMapOp::verify() {
1908 if (getSourceVectorType().getRank() != getResultType().getRank())
1909 return emitOpError("expected source and destination vectors of same rank");
1910 unsigned numId = 0;
1911 for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) {
1912 if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) !=
1913 0)
1914 return emitOpError(
1915 "destination vector size must be a multiple of source vector size");
1916 if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i))
1917 numId++;
1918 }
1919 if (numId != ids().size())
1920 return emitOpError("expected number of ids must match the number of "
1921 "dimensions distributed");
1922 return success();
1923 }
1924
1925 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
1926
1927 //===----------------------------------------------------------------------===//
1928 // InsertStridedSliceOp
1929 //===----------------------------------------------------------------------===//
1930
1931 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1932 Value source, Value dest,
1933 ArrayRef<int64_t> offsets,
1934 ArrayRef<int64_t> strides) {
1935 result.addOperands({source, dest});
1936 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1937 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1938 result.addTypes(dest.getType());
1939 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
1940 result.addAttribute(getStridesAttrStrName(), stridesAttr);
1941 }
1942
1943 // TODO: Should be moved to Tablegen Confined attributes.
1944 template <typename OpType>
1945 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
1946 ArrayAttr arrayAttr,
1947 ArrayRef<int64_t> shape,
1948 StringRef attrName) {
1949 if (arrayAttr.size() > shape.size())
1950 return op.emitOpError("expected ")
1951 << attrName << " attribute of rank smaller than vector rank";
1952 return success();
1953 }
1954
1955 // Returns success if all integers in `arrayAttr` are confined to the range
1956 // defined by `min` and `max`. If `halfOpen` is true the admissible interval
1957 // is [min, max); otherwise it is [min, max].
1958 template <typename OpType>
1959 static LogicalResult
1960 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
1961 int64_t max, StringRef attrName,
1962 bool halfOpen = true) {
1963 for (auto attr : arrayAttr) {
1964 auto val = attr.cast<IntegerAttr>().getInt();
1965 auto upper = max;
1966 if (!halfOpen)
1967 upper += 1;
1968 if (val < min || val >= upper)
1969 return op.emitOpError("expected ") << attrName << " to be confined to ["
1970 << min << ", " << upper << ")";
1971 }
1972 return success();
1973 }
1974
1975 // Returns success if all integers in `arrayAttr` are confined to the
1976 // corresponding dimension in `shape`. If `halfOpen` is true the admissible
1977 // interval for dimension d is [min, shape[d]); otherwise it is [min, shape[d]].
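// For example (illustrative): with shape = [4, 8], min = 0 and
// halfOpen = true, the attribute [3, 7] is accepted while [3, 8] is
// rejected because 8 is not in [0, 8).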
1978 template <typename OpType>
1979 static LogicalResult
1980 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
1981 ArrayRef<int64_t> shape, StringRef attrName,
1982 bool halfOpen = true, int64_t min = 0) {
1983 assert(arrayAttr.size() <= shape.size());
1984 unsigned index = 0;
1985 for (auto it : llvm::zip(arrayAttr, shape)) {
1986 auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
1987 auto max = std::get<1>(it);
1988 if (!halfOpen)
1989 max += 1;
1990 if (val < min || val >= max)
1991 return op.emitOpError("expected ")
1992 << attrName << " dimension " << index << " to be confined to ["
1993 << min << ", " << max << ")";
1994 ++index;
1995 }
1996 return success();
1997 }
1998
1999 // Returns success if, for each dimension d, the sum of the corresponding
2000 // integers in `arrayAttr1` and `arrayAttr2` is confined to shape[d]. If
2001 // `halfOpen` is true the admissible upper bound is exclusive; otherwise it is inclusive.
2002 template <typename OpType>
2003 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2004 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2005 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2006 bool halfOpen = true, int64_t min = 1) {
2007 assert(arrayAttr1.size() <= shape.size());
2008 assert(arrayAttr2.size() <= shape.size());
2009 unsigned index = 0;
2010 for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
2011 auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
2012 auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
2013 auto max = std::get<2>(it);
2014 if (!halfOpen)
2015 max += 1;
2016 if (val1 + val2 < 0 || val1 + val2 >= max)
2017 return op.emitOpError("expected sum(")
2018 << attrName1 << ", " << attrName2 << ") dimension " << index
2019 << " to be confined to [" << min << ", " << max << ")";
2020 ++index;
2021 }
2022 return success();
2023 }
2024
2025 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2026 MLIRContext *context) {
2027 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2028 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2029 });
2030 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2031 }
2032
2033 LogicalResult InsertStridedSliceOp::verify() {
2034 auto sourceVectorType = getSourceVectorType();
2035 auto destVectorType = getDestVectorType();
2036 auto offsets = offsetsAttr();
2037 auto strides = stridesAttr();
2038 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2039 return emitOpError(
2040 "expected offsets of same size as destination vector rank");
2041 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2042 return emitOpError("expected strides of same size as source vector rank");
2043 if (sourceVectorType.getRank() > destVectorType.getRank())
2044 return emitOpError(
2045 "expected source rank to be smaller than destination rank");
2046
2047 auto sourceShape = sourceVectorType.getShape();
2048 auto destShape = destVectorType.getShape();
2049 SmallVector<int64_t, 4> sourceShapeAsDestShape(
2050 destShape.size() - sourceShape.size(), 0);
2051 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2052 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2053 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2054 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2055 offName)) ||
2056 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2057 stridesName,
2058 /*halfOpen=*/false)) ||
2059
failed(isSumOfIntegerArrayAttrConfinedToShape( 2060 *this, offsets, 2061 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape, 2062 offName, "source vector shape", 2063 /*halfOpen=*/false, /*min=*/1))) 2064 return failure(); 2065 2066 return success(); 2067 } 2068 2069 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2070 if (getSourceVectorType() == getDestVectorType()) 2071 return source(); 2072 return {}; 2073 } 2074 2075 //===----------------------------------------------------------------------===// 2076 // OuterProductOp 2077 //===----------------------------------------------------------------------===// 2078 2079 /// Build an op without mask, use the type of `acc` as the return type. 2080 void OuterProductOp::build(OpBuilder &builder, OperationState &result, 2081 Value lhs, Value rhs, Value acc) { 2082 result.addOperands({lhs, rhs, acc}); 2083 result.addTypes(acc.getType()); 2084 } 2085 2086 void OuterProductOp::print(OpAsmPrinter &p) { 2087 p << " " << lhs() << ", " << rhs(); 2088 if (!acc().empty()) { 2089 p << ", " << acc(); 2090 p.printOptionalAttrDict((*this)->getAttrs()); 2091 } 2092 p << " : " << lhs().getType() << ", " << rhs().getType(); 2093 } 2094 2095 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { 2096 SmallVector<OpAsmParser::OperandType, 3> operandsInfo; 2097 Type tLHS, tRHS; 2098 if (parser.parseOperandList(operandsInfo) || 2099 parser.parseOptionalAttrDict(result.attributes) || 2100 parser.parseColonType(tLHS) || parser.parseComma() || 2101 parser.parseType(tRHS)) 2102 return failure(); 2103 if (operandsInfo.size() < 2) 2104 return parser.emitError(parser.getNameLoc(), 2105 "expected at least 2 operands"); 2106 VectorType vLHS = tLHS.dyn_cast<VectorType>(); 2107 VectorType vRHS = tRHS.dyn_cast<VectorType>(); 2108 if (!vLHS) 2109 return parser.emitError(parser.getNameLoc(), 2110 "expected vector type for operand #1"); 2111 VectorType resType = 2112 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, 2113 vLHS.getElementType()) 2114 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); 2115 2116 if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { 2117 result.attributes.append( 2118 OuterProductOp::getKindAttrStrName(), 2119 CombiningKindAttr::get(OuterProductOp::getDefaultKind(), 2120 result.getContext())); 2121 } 2122 2123 return failure( 2124 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || 2125 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || 2126 (operandsInfo.size() > 2 && 2127 parser.resolveOperand(operandsInfo[2], resType, result.operands)) || 2128 parser.addTypeToList(resType, result.types)); 2129 } 2130 2131 LogicalResult OuterProductOp::verify() { 2132 Type tRHS = getOperandTypeRHS(); 2133 VectorType vLHS = getOperandVectorTypeLHS(), 2134 vRHS = tRHS.dyn_cast<VectorType>(), 2135 vACC = getOperandVectorTypeACC(), vRES = getVectorType(); 2136 2137 if (vLHS.getRank() != 1) 2138 return emitOpError("expected 1-d vector for operand #1"); 2139 2140 if (vRHS) { 2141 // Proper OUTER operation. 
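// e.g. (an illustrative sketch):
//   %r = vector.outerproduct %lhs, %rhs, %acc
//          : vector<4xf32>, vector<8xf32>
// where %acc and %r are of type vector<4x8xf32>.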
2142 if (vRHS.getRank() != 1) 2143 return emitOpError("expected 1-d vector for operand #2"); 2144 if (vRES.getRank() != 2) 2145 return emitOpError("expected 2-d vector result"); 2146 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2147 return emitOpError("expected #1 operand dim to match result dim #1"); 2148 if (vRHS.getDimSize(0) != vRES.getDimSize(1)) 2149 return emitOpError("expected #2 operand dim to match result dim #2"); 2150 } else { 2151 // An AXPY operation. 2152 if (vRES.getRank() != 1) 2153 return emitOpError("expected 1-d vector result"); 2154 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2155 return emitOpError("expected #1 operand dim to match result dim #1"); 2156 } 2157 2158 if (vACC && vACC != vRES) 2159 return emitOpError("expected operand #3 of same type as result type"); 2160 2161 // Verify supported combining kind. 2162 if (!isSupportedCombiningKind(kind(), vRES.getElementType())) 2163 return emitOpError("unsupported outerproduct type"); 2164 2165 return success(); 2166 } 2167 2168 //===----------------------------------------------------------------------===// 2169 // ReshapeOp 2170 //===----------------------------------------------------------------------===// 2171 2172 LogicalResult ReshapeOp::verify() { 2173 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. 2174 auto inputVectorType = getInputVectorType(); 2175 auto outputVectorType = getOutputVectorType(); 2176 int64_t inputShapeRank = getNumInputShapeSizes(); 2177 int64_t outputShapeRank = getNumOutputShapeSizes(); 2178 SmallVector<int64_t, 4> fixedVectorSizes; 2179 getFixedVectorSizes(fixedVectorSizes); 2180 int64_t numFixedVectorSizes = fixedVectorSizes.size(); 2181 2182 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) 2183 return emitError("invalid input shape for vector type ") 2184 << inputVectorType; 2185 2186 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) 2187 return emitError("invalid output shape for vector type ") 2188 << outputVectorType; 2189 2190 // Verify that the 'fixedVectorSizes' match an input/output vector shape 2191 // suffix. 2192 unsigned inputVectorRank = inputVectorType.getRank(); 2193 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2194 unsigned index = inputVectorRank - numFixedVectorSizes - i; 2195 if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) 2196 return emitError("fixed vector size must match input vector for dim ") 2197 << i; 2198 } 2199 2200 unsigned outputVectorRank = outputVectorType.getRank(); 2201 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2202 unsigned index = outputVectorRank - numFixedVectorSizes - i; 2203 if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) 2204 return emitError("fixed vector size must match output vector for dim ") 2205 << i; 2206 } 2207 2208 // If all shape operands are produced by constant ops, verify that product 2209 // of dimensions for input/output shape match. 
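// For example (illustrative): constant input shape (2, 6) and output shape
// (3, 4) are compatible since 2 * 6 == 3 * 4 == 12, whereas (2, 6) and
// (3, 5) would be rejected.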
2210 auto isDefByConstant = [](Value operand) { 2211 return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp()); 2212 }; 2213 if (llvm::all_of(input_shape(), isDefByConstant) && 2214 llvm::all_of(output_shape(), isDefByConstant)) { 2215 int64_t numInputElements = 1; 2216 for (auto operand : input_shape()) 2217 numInputElements *= 2218 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2219 int64_t numOutputElements = 1; 2220 for (auto operand : output_shape()) 2221 numOutputElements *= 2222 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2223 if (numInputElements != numOutputElements) 2224 return emitError("product of input and output shape sizes must match"); 2225 } 2226 return success(); 2227 } 2228 2229 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) { 2230 populateFromInt64AttrArray(fixed_vector_sizes(), results); 2231 } 2232 2233 //===----------------------------------------------------------------------===// 2234 // ExtractStridedSliceOp 2235 //===----------------------------------------------------------------------===// 2236 2237 // Inference works as follows: 2238 // 1. Add 'sizes' from prefix of dims in 'offsets'. 2239 // 2. Add sizes from 'vectorType' for remaining dims. 2240 static Type inferStridedSliceOpResultType(VectorType vectorType, 2241 ArrayAttr offsets, ArrayAttr sizes, 2242 ArrayAttr strides) { 2243 assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); 2244 SmallVector<int64_t, 4> shape; 2245 shape.reserve(vectorType.getRank()); 2246 unsigned idx = 0; 2247 for (unsigned e = offsets.size(); idx < e; ++idx) 2248 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt()); 2249 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) 2250 shape.push_back(vectorType.getShape()[idx]); 2251 2252 return VectorType::get(shape, vectorType.getElementType()); 2253 } 2254 2255 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, 2256 Value source, ArrayRef<int64_t> offsets, 2257 ArrayRef<int64_t> sizes, 2258 ArrayRef<int64_t> strides) { 2259 result.addOperands(source); 2260 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 2261 auto sizesAttr = getVectorSubscriptAttr(builder, sizes); 2262 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 2263 result.addTypes( 2264 inferStridedSliceOpResultType(source.getType().cast<VectorType>(), 2265 offsetsAttr, sizesAttr, stridesAttr)); 2266 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); 2267 result.addAttribute(getSizesAttrStrName(), sizesAttr); 2268 result.addAttribute(getStridesAttrStrName(), stridesAttr); 2269 } 2270 2271 LogicalResult ExtractStridedSliceOp::verify() { 2272 auto type = getVectorType(); 2273 auto offsets = offsetsAttr(); 2274 auto sizes = sizesAttr(); 2275 auto strides = stridesAttr(); 2276 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) 2277 return emitOpError("expected offsets, sizes and strides attributes of same size"); 2278 2279 auto shape = type.getShape(); 2280 auto offName = getOffsetsAttrName(); 2281 auto sizesName = getSizesAttrName(); 2282 auto stridesName = getStridesAttrName(); 2283 if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || 2284 failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || 2285 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape, 2286 stridesName)) || 2287 failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || 2288 
failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName, 2289 /*halfOpen=*/false, 2290 /*min=*/1)) || 2291 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName, 2292 /*halfOpen=*/false)) || 2293 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape, 2294 offName, sizesName, 2295 /*halfOpen=*/false))) 2296 return failure(); 2297 2298 auto resultType = 2299 inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides); 2300 if (getResult().getType() != resultType) 2301 return emitOpError("expected result type to be ") << resultType; 2302 2303 return success(); 2304 } 2305 2306 // When the source of ExtractStrided comes from a chain of InsertStrided ops try 2307 // to use the source of the InsertStrided ops if we can detect that the 2308 // extracted vector is a subset of one of the vector inserted. 2309 static LogicalResult 2310 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { 2311 // Helper to extract integer out of ArrayAttr. 2312 auto getElement = [](ArrayAttr array, int idx) { 2313 return array[idx].cast<IntegerAttr>().getInt(); 2314 }; 2315 ArrayAttr extractOffsets = op.offsets(); 2316 ArrayAttr extractStrides = op.strides(); 2317 ArrayAttr extractSizes = op.sizes(); 2318 auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>(); 2319 while (insertOp) { 2320 if (op.getVectorType().getRank() != 2321 insertOp.getSourceVectorType().getRank()) 2322 return failure(); 2323 ArrayAttr insertOffsets = insertOp.offsets(); 2324 ArrayAttr insertStrides = insertOp.strides(); 2325 // If the rank of extract is greater than the rank of insert, we are likely 2326 // extracting a partial chunk of the vector inserted. 2327 if (extractOffsets.size() > insertOffsets.size()) 2328 return failure(); 2329 bool patialoverlap = false; 2330 bool disjoint = false; 2331 SmallVector<int64_t, 4> offsetDiffs; 2332 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) { 2333 if (getElement(extractStrides, dim) != getElement(insertStrides, dim)) 2334 return failure(); 2335 int64_t start = getElement(insertOffsets, dim); 2336 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim); 2337 int64_t offset = getElement(extractOffsets, dim); 2338 int64_t size = getElement(extractSizes, dim); 2339 // Check if the start of the extract offset is in the interval inserted. 2340 if (start <= offset && offset < end) { 2341 // If the extract interval overlaps but is not fully included we may 2342 // have a partial overlap that will prevent any folding. 2343 if (offset + size > end) 2344 patialoverlap = true; 2345 offsetDiffs.push_back(offset - start); 2346 continue; 2347 } 2348 disjoint = true; 2349 break; 2350 } 2351 // The extract element chunk is a subset of the insert element. 2352 if (!disjoint && !patialoverlap) { 2353 op.setOperand(insertOp.source()); 2354 // OpBuilder is only used as a helper to build an I64ArrayAttr. 2355 OpBuilder b(op.getContext()); 2356 op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(), 2357 b.getI64ArrayAttr(offsetDiffs)); 2358 return success(); 2359 } 2360 // If the chunk extracted is disjoint from the chunk inserted, keep looking 2361 // in the insert chain. 2362 if (disjoint) 2363 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>(); 2364 else { 2365 // The extracted vector partially overlap the inserted vector, we cannot 2366 // fold. 
2367 return failure(); 2368 } 2369 } 2370 return failure(); 2371 } 2372 2373 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2374 if (getVectorType() == getResult().getType()) 2375 return vector(); 2376 if (succeeded(foldExtractStridedOpFromInsertChain(*this))) 2377 return getResult(); 2378 return {}; 2379 } 2380 2381 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) { 2382 populateFromInt64AttrArray(offsets(), results); 2383 } 2384 2385 namespace { 2386 2387 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to 2388 // ConstantMaskOp. 2389 class StridedSliceConstantMaskFolder final 2390 : public OpRewritePattern<ExtractStridedSliceOp> { 2391 public: 2392 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2393 2394 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2395 PatternRewriter &rewriter) const override { 2396 // Return if 'extractStridedSliceOp' operand is not defined by a 2397 // ConstantMaskOp. 2398 auto *defOp = extractStridedSliceOp.vector().getDefiningOp(); 2399 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp); 2400 if (!constantMaskOp) 2401 return failure(); 2402 // Return if 'extractStridedSliceOp' has non-unit strides. 2403 if (extractStridedSliceOp.hasNonUnitStrides()) 2404 return failure(); 2405 // Gather constant mask dimension sizes. 2406 SmallVector<int64_t, 4> maskDimSizes; 2407 populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); 2408 // Gather strided slice offsets and sizes. 2409 SmallVector<int64_t, 4> sliceOffsets; 2410 populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets); 2411 SmallVector<int64_t, 4> sliceSizes; 2412 populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes); 2413 2414 // Compute slice of vector mask region. 2415 SmallVector<int64_t, 4> sliceMaskDimSizes; 2416 assert(sliceOffsets.size() == maskDimSizes.size()); 2417 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { 2418 int64_t maskDimSize = std::get<0>(it); 2419 int64_t sliceOffset = std::get<1>(it); 2420 int64_t sliceSize = std::get<2>(it); 2421 int64_t sliceMaskDimSize = std::max( 2422 static_cast<int64_t>(0), 2423 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset); 2424 sliceMaskDimSizes.push_back(sliceMaskDimSize); 2425 } 2426 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked 2427 // region is a conjunction of mask dim intervals). 2428 if (llvm::is_contained(sliceMaskDimSizes, 0)) 2429 sliceMaskDimSizes.assign(maskDimSizes.size(), 0); 2430 2431 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask 2432 // region. 2433 rewriter.replaceOpWithNewOp<ConstantMaskOp>( 2434 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), 2435 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); 2436 return success(); 2437 } 2438 }; 2439 2440 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. 2441 class StridedSliceConstantFolder final 2442 : public OpRewritePattern<ExtractStridedSliceOp> { 2443 public: 2444 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2445 2446 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2447 PatternRewriter &rewriter) const override { 2448 // Return if 'extractStridedSliceOp' operand is not defined by a 2449 // ConstantOp. 
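// A minimal sketch of the rewrite (illustrative values, not from a test):
//   %c = arith.constant dense<3.0> : vector<8xf32>
//   %s = vector.extract_strided_slice %c
//          {offsets = [2], sizes = [4], strides = [1]}
//          : vector<8xf32> to vector<4xf32>
// becomes
//   %s = arith.constant dense<3.0> : vector<4xf32>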
2450 auto constantOp = 2451 extractStridedSliceOp.vector().getDefiningOp<arith::ConstantOp>(); 2452 if (!constantOp) 2453 return failure(); 2454 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 2455 if (!dense) 2456 return failure(); 2457 auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), 2458 dense.getSplatValue<Attribute>()); 2459 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp, 2460 newAttr); 2461 return success(); 2462 } 2463 }; 2464 2465 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to 2466 // BroadcastOp(ExtractStrideSliceOp). 2467 class StridedSliceBroadcast final 2468 : public OpRewritePattern<ExtractStridedSliceOp> { 2469 public: 2470 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2471 2472 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 2473 PatternRewriter &rewriter) const override { 2474 auto broadcast = op.vector().getDefiningOp<BroadcastOp>(); 2475 if (!broadcast) 2476 return failure(); 2477 auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>(); 2478 unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0; 2479 auto dstVecType = op.getType().cast<VectorType>(); 2480 unsigned dstRank = dstVecType.getRank(); 2481 unsigned rankDiff = dstRank - srcRrank; 2482 // Check if the most inner dimensions of the source of the broadcast are the 2483 // same as the destination of the extract. If this is the case we can just 2484 // use a broadcast as the original dimensions are untouched. 2485 bool lowerDimMatch = true; 2486 for (unsigned i = 0; i < srcRrank; i++) { 2487 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) { 2488 lowerDimMatch = false; 2489 break; 2490 } 2491 } 2492 Value source = broadcast.source(); 2493 if (!lowerDimMatch) { 2494 // The inner dimensions don't match, it means we need to extract from the 2495 // source of the orignal broadcast and then broadcast the extracted value. 2496 source = rewriter.create<ExtractStridedSliceOp>( 2497 op->getLoc(), source, 2498 getI64SubArray(op.offsets(), /* dropFront=*/rankDiff), 2499 getI64SubArray(op.sizes(), /* dropFront=*/rankDiff), 2500 getI64SubArray(op.strides(), /* dropFront=*/rankDiff)); 2501 } 2502 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source); 2503 return success(); 2504 } 2505 }; 2506 2507 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp. 2508 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { 2509 public: 2510 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2511 2512 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 2513 PatternRewriter &rewriter) const override { 2514 auto splat = op.vector().getDefiningOp<SplatOp>(); 2515 if (!splat) 2516 return failure(); 2517 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input()); 2518 return success(); 2519 } 2520 }; 2521 2522 } // namespace 2523 2524 void ExtractStridedSliceOp::getCanonicalizationPatterns( 2525 RewritePatternSet &results, MLIRContext *context) { 2526 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> 2527 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. 
2528 results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
2529 StridedSliceBroadcast, StridedSliceSplat>(context);
2530 }
2531
2532 //===----------------------------------------------------------------------===//
2533 // TransferReadOp
2534 //===----------------------------------------------------------------------===//
2535
2536 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
2537 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2538 VectorType vectorType, Value source,
2539 ValueRange indices, AffineMapAttr permutationMapAttr,
2540 /*optional*/ ArrayAttr inBoundsAttr) {
2541 Type elemType = source.getType().cast<ShapedType>().getElementType();
2542 Value padding = builder.create<arith::ConstantOp>(
2543 result.location, elemType, builder.getZeroAttr(elemType));
2544 build(builder, result, vectorType, source, indices, permutationMapAttr,
2545 padding, /*mask=*/Value(), inBoundsAttr);
2546 }
2547
2548 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
2549 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2550 VectorType vectorType, Value source,
2551 ValueRange indices, AffineMap permutationMap,
2552 Optional<ArrayRef<bool>> inBounds) {
2553 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2554 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
2555 ? builder.getBoolArrayAttr(inBounds.getValue())
2556 : ArrayAttr();
2557 build(builder, result, vectorType, source, indices, permutationMapAttr,
2558 inBoundsAttr);
2559 }
2560
2561 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
2562 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2563 VectorType vectorType, Value source,
2564 ValueRange indices, Value padding,
2565 Optional<ArrayRef<bool>> inBounds) {
2566 AffineMap permutationMap = getTransferMinorIdentityMap(
2567 source.getType().cast<ShapedType>(), vectorType);
2568 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2569 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
2570 ? builder.getBoolArrayAttr(inBounds.getValue())
2571 : ArrayAttr();
2572 build(builder, result, vectorType, source, indices, permutationMapAttr,
2573 padding,
2574 /*mask=*/Value(), inBoundsAttr);
2575 }
2576
2577 /// 4. Builder that sets padding to zero and permutation map to
2578 /// 'getMinorIdentityMap'.
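/// A hedged usage sketch (names are illustrative and assume the in-bounds
/// argument can be left defaulted):
///   builder.create<vector::TransferReadOp>(loc, vecType, source, indices);
/// creates a zero-padded, minor-identity read of `vecType` from `source`.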
2579 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2580 VectorType vectorType, Value source, 2581 ValueRange indices, 2582 Optional<ArrayRef<bool>> inBounds) { 2583 Type elemType = source.getType().cast<ShapedType>().getElementType(); 2584 Value padding = builder.create<arith::ConstantOp>( 2585 result.location, elemType, builder.getZeroAttr(elemType)); 2586 build(builder, result, vectorType, source, indices, padding, inBounds); 2587 } 2588 2589 template <typename EmitFun> 2590 static LogicalResult verifyPermutationMap(AffineMap permutationMap, 2591 EmitFun emitOpError) { 2592 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); 2593 for (auto expr : permutationMap.getResults()) { 2594 auto dim = expr.dyn_cast<AffineDimExpr>(); 2595 auto zero = expr.dyn_cast<AffineConstantExpr>(); 2596 if (zero) { 2597 if (zero.getValue() != 0) { 2598 return emitOpError( 2599 "requires a projected permutation_map (at most one dim or the zero " 2600 "constant can appear in each result)"); 2601 } 2602 continue; 2603 } 2604 if (!dim) { 2605 return emitOpError("requires a projected permutation_map (at most one " 2606 "dim or the zero constant can appear in each result)"); 2607 } 2608 if (seen[dim.getPosition()]) { 2609 return emitOpError( 2610 "requires a permutation_map that is a permutation (found one dim " 2611 "used more than once)"); 2612 } 2613 seen[dim.getPosition()] = true; 2614 } 2615 return success(); 2616 } 2617 2618 static LogicalResult 2619 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, 2620 VectorType vectorType, VectorType maskType, 2621 AffineMap permutationMap, ArrayAttr inBounds) { 2622 if (op->hasAttr("masked")) { 2623 return op->emitOpError("masked attribute has been removed. " 2624 "Use in_bounds instead."); 2625 } 2626 2627 if (!shapedType.isa<MemRefType, RankedTensorType>()) 2628 return op->emitOpError( 2629 "requires source to be a memref or ranked tensor type"); 2630 2631 auto elementType = shapedType.getElementType(); 2632 DataLayout dataLayout = DataLayout::closest(op); 2633 if (auto vectorElementType = elementType.dyn_cast<VectorType>()) { 2634 // Memref or tensor has vector element type. 2635 unsigned sourceVecSize = 2636 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * 2637 vectorElementType.getShape().back(); 2638 unsigned resultVecSize = 2639 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * 2640 vectorType.getShape().back(); 2641 if (resultVecSize % sourceVecSize != 0) 2642 return op->emitOpError( 2643 "requires the bitwidth of the minor 1-D vector to be an integral " 2644 "multiple of the bitwidth of the minor 1-D vector of the source"); 2645 2646 unsigned sourceVecEltRank = vectorElementType.getRank(); 2647 unsigned resultVecRank = vectorType.getRank(); 2648 if (sourceVecEltRank > resultVecRank) 2649 return op->emitOpError( 2650 "requires source vector element and vector result ranks to match."); 2651 unsigned rankOffset = resultVecRank - sourceVecEltRank; 2652 // Check that permutation map results match 'rankOffset' of vector type. 2653 if (permutationMap.getNumResults() != rankOffset) 2654 return op->emitOpError("requires a permutation_map with result dims of " 2655 "the same rank as the vector type"); 2656 2657 if (maskType) 2658 return op->emitOpError("does not support masks with vector element type"); 2659 } else { 2660 // Memref or tensor has scalar element type. 2661 unsigned minorSize = 2662 vectorType.getRank() == 0 ? 
1 : vectorType.getShape().back(); 2663 unsigned resultVecSize = 2664 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize; 2665 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0) 2666 return op->emitOpError( 2667 "requires the bitwidth of the minor 1-D vector to be an integral " 2668 "multiple of the bitwidth of the source element type"); 2669 2670 // Check that permutation map results match rank of vector type. 2671 if (permutationMap.getNumResults() != vectorType.getRank()) 2672 return op->emitOpError("requires a permutation_map with result dims of " 2673 "the same rank as the vector type"); 2674 2675 VectorType expectedMaskType = 2676 vector::detail::transferMaskType(vectorType, permutationMap); 2677 if (maskType && expectedMaskType != maskType) 2678 return op->emitOpError("expects mask type consistent with permutation " 2679 "map: ") 2680 << maskType; 2681 } 2682 2683 if (permutationMap.getNumSymbols() != 0) 2684 return op->emitOpError("requires permutation_map without symbols"); 2685 2686 if (permutationMap.getNumInputs() != shapedType.getRank()) 2687 return op->emitOpError("requires a permutation_map with input dims of the " 2688 "same rank as the source type"); 2689 2690 if (inBounds) { 2691 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size())) 2692 return op->emitOpError("expects the optional in_bounds attr of same rank " 2693 "as permutation_map results: ") 2694 << AffineMapAttr::get(permutationMap) 2695 << " vs inBounds of size: " << inBounds.size(); 2696 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) 2697 if (permutationMap.getResult(i).isa<AffineConstantExpr>() && 2698 !inBounds.getValue()[i].cast<BoolAttr>().getValue()) 2699 return op->emitOpError("requires broadcast dimensions to be in-bounds"); 2700 } 2701 2702 return success(); 2703 } 2704 2705 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { 2706 SmallVector<StringRef, 3> elidedAttrs; 2707 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); 2708 if (op.permutation_map().isMinorIdentity()) 2709 elidedAttrs.push_back(op.getPermutationMapAttrStrName()); 2710 bool elideInBounds = true; 2711 if (auto inBounds = op.in_bounds()) { 2712 for (auto attr : *inBounds) { 2713 if (attr.template cast<BoolAttr>().getValue()) { 2714 elideInBounds = false; 2715 break; 2716 } 2717 } 2718 } 2719 if (elideInBounds) 2720 elidedAttrs.push_back(op.getInBoundsAttrStrName()); 2721 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); 2722 } 2723 2724 void TransferReadOp::print(OpAsmPrinter &p) { 2725 p << " " << source() << "[" << indices() << "], " << padding(); 2726 if (mask()) 2727 p << ", " << mask(); 2728 printTransferAttrs(p, *this); 2729 p << " : " << getShapedType() << ", " << getVectorType(); 2730 } 2731 2732 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { 2733 auto &builder = parser.getBuilder(); 2734 SMLoc typesLoc; 2735 OpAsmParser::OperandType sourceInfo; 2736 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 2737 OpAsmParser::OperandType paddingInfo; 2738 SmallVector<Type, 2> types; 2739 OpAsmParser::OperandType maskInfo; 2740 // Parsing with support for paddingValue. 
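// The custom forms accepted here look like (illustrative):
//   vector.transfer_read %src[%i, %j], %pad
//     : memref<?x?xf32>, vector<16xf32>
//   vector.transfer_read %src[%i, %j], %pad, %mask {in_bounds = [true]}
//     : memref<?x?xf32>, vector<16xf32>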
2741 if (parser.parseOperand(sourceInfo) || 2742 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 2743 parser.parseComma() || parser.parseOperand(paddingInfo)) 2744 return failure(); 2745 ParseResult hasMask = parser.parseOptionalComma(); 2746 if (hasMask.succeeded()) { 2747 parser.parseOperand(maskInfo); 2748 } 2749 if (parser.parseOptionalAttrDict(result.attributes) || 2750 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 2751 return failure(); 2752 if (types.size() != 2) 2753 return parser.emitError(typesLoc, "requires two types"); 2754 auto indexType = builder.getIndexType(); 2755 auto shapedType = types[0].dyn_cast<ShapedType>(); 2756 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 2757 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 2758 VectorType vectorType = types[1].dyn_cast<VectorType>(); 2759 if (!vectorType) 2760 return parser.emitError(typesLoc, "requires vector type"); 2761 auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName(); 2762 Attribute mapAttr = result.attributes.get(permutationAttrName); 2763 if (!mapAttr) { 2764 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 2765 // Update `mapAttr` that is used later to determine mask type. 2766 mapAttr = AffineMapAttr::get(permMap); 2767 result.attributes.set(permutationAttrName, mapAttr); 2768 } 2769 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || 2770 parser.resolveOperands(indexInfo, indexType, result.operands) || 2771 parser.resolveOperand(paddingInfo, shapedType.getElementType(), 2772 result.operands)) 2773 return failure(); 2774 if (hasMask.succeeded()) { 2775 if (shapedType.getElementType().dyn_cast<VectorType>()) 2776 return parser.emitError( 2777 maskInfo.location, "does not support masks with vector element type"); 2778 auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue(); 2779 // Instead of adding the mask type as an op type, compute it based on the 2780 // vector type and the permutation map (to keep the type signature small). 2781 auto maskType = mlir::vector::detail::transferMaskType(vectorType, map); 2782 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 2783 return failure(); 2784 } 2785 result.addAttribute( 2786 TransferReadOp::getOperandSegmentSizeAttr(), 2787 builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1, 2788 static_cast<int32_t>(hasMask.succeeded())})); 2789 return parser.addTypeToList(vectorType, result.types); 2790 } 2791 2792 LogicalResult TransferReadOp::verify() { 2793 // Consistency of elemental types in source and vector. 2794 ShapedType shapedType = getShapedType(); 2795 VectorType vectorType = getVectorType(); 2796 VectorType maskType = getMaskType(); 2797 auto paddingType = padding().getType(); 2798 auto permutationMap = permutation_map(); 2799 auto sourceElementType = shapedType.getElementType(); 2800 2801 if (static_cast<int64_t>(indices().size()) != shapedType.getRank()) 2802 return emitOpError("requires ") << shapedType.getRank() << " indices"; 2803 2804 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 2805 shapedType, vectorType, maskType, permutationMap, 2806 in_bounds() ? *in_bounds() : ArrayAttr()))) 2807 return failure(); 2808 2809 if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) { 2810 // Source has vector element type. 2811 // Check that 'sourceVectorElementType' and 'paddingType' types match. 
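// e.g. (illustrative) a read from memref<?xvector<4xf32>> requires the
// padding operand to also be of type vector<4xf32>.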
2812 if (sourceVectorElementType != paddingType) 2813 return emitOpError( 2814 "requires source element type and padding type to match."); 2815 2816 } else { 2817 // Check that 'paddingType' is valid to store in a vector type. 2818 if (!VectorType::isValidElementType(paddingType)) 2819 return emitOpError("requires valid padding vector elemental type"); 2820 2821 // Check that padding type and vector element types match. 2822 if (paddingType != sourceElementType) 2823 return emitOpError( 2824 "requires formal padding and source of the same elemental type"); 2825 } 2826 2827 return verifyPermutationMap(permutationMap, 2828 [&](Twine t) { return emitOpError(t); }); 2829 } 2830 2831 /// This is a common class used for patterns of the form 2832 /// ``` 2833 /// someop(memrefcast) -> someop 2834 /// ``` 2835 /// It folds the source of the memref.cast into the root operation directly. 2836 static LogicalResult foldMemRefCast(Operation *op) { 2837 bool folded = false; 2838 for (OpOperand &operand : op->getOpOperands()) { 2839 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 2840 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 2841 operand.set(castOp.getOperand()); 2842 folded = true; 2843 } 2844 } 2845 return success(folded); 2846 } 2847 2848 static LogicalResult foldTensorCast(Operation *op) { 2849 bool folded = false; 2850 for (OpOperand &operand : op->getOpOperands()) { 2851 auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); 2852 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { 2853 operand.set(castOp.getOperand()); 2854 folded = true; 2855 } 2856 } 2857 return success(folded); 2858 } 2859 2860 template <typename TransferOp> 2861 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { 2862 // TODO: support more aggressive createOrFold on: 2863 // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)` 2864 if (op.getShapedType().isDynamicDim(indicesIdx)) 2865 return false; 2866 Value index = op.indices()[indicesIdx]; 2867 auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>(); 2868 if (!cstOp) 2869 return false; 2870 2871 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); 2872 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); 2873 2874 return cstOp.value() + vectorSize <= sourceSize; 2875 } 2876 2877 template <typename TransferOp> 2878 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { 2879 // TODO: support 0-d corner case. 2880 // TODO: Be less conservative. 2881 if (op.getTransferRank() == 0) 2882 return failure(); 2883 AffineMap permutationMap = op.permutation_map(); 2884 bool changed = false; 2885 SmallVector<bool, 4> newInBounds; 2886 newInBounds.reserve(op.getTransferRank()); 2887 for (unsigned i = 0; i < op.getTransferRank(); ++i) { 2888 // Already marked as in-bounds, nothing to see here. 2889 if (op.isDimInBounds(i)) { 2890 newInBounds.push_back(true); 2891 continue; 2892 } 2893 // Currently out-of-bounds, check whether we can statically determine it is 2894 // inBounds. 2895 auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>(); 2896 assert(dimExpr && "Broadcast dims must be in-bounds"); 2897 auto inBounds = 2898 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition()); 2899 newInBounds.push_back(inBounds); 2900 // We commit the pattern if it is "more inbounds". 2901 changed |= inBounds; 2902 } 2903 if (!changed) 2904 return failure(); 2905 // OpBuilder is only used as a helper to build an I64ArrayAttr. 
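// Note: the per-dimension in_bounds flags computed above are stored below as a
// BoolArrayAttr on the op.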
2906 OpBuilder b(op.getContext()); 2907 op->setAttr(TransferOp::getInBoundsAttrStrName(), 2908 b.getBoolArrayAttr(newInBounds)); 2909 return success(); 2910 } 2911 2912 /// ``` 2913 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} 2914 /// : vector<1x4xf32>, tensor<4x4xf32> 2915 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} 2916 /// : tensor<4x4xf32>, vector<1x4xf32> 2917 /// ``` 2918 /// -> Folds into 2919 /// ``` 2920 /// %v0 2921 /// ``` 2922 static Value foldRAW(TransferReadOp readOp) { 2923 if (!readOp.getShapedType().isa<RankedTensorType>()) 2924 return {}; 2925 auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>(); 2926 while (defWrite) { 2927 if (checkSameValueRAW(defWrite, readOp)) 2928 return defWrite.vector(); 2929 if (!isDisjointTransferIndices( 2930 cast<VectorTransferOpInterface>(defWrite.getOperation()), 2931 cast<VectorTransferOpInterface>(readOp.getOperation()))) 2932 break; 2933 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>(); 2934 } 2935 return {}; 2936 } 2937 2938 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) { 2939 if (Value vec = foldRAW(*this)) 2940 return vec; 2941 /// transfer_read(memrefcast) -> transfer_read 2942 if (succeeded(foldTransferInBoundsAttribute(*this))) 2943 return getResult(); 2944 if (succeeded(foldMemRefCast(*this))) 2945 return getResult(); 2946 if (succeeded(foldTensorCast(*this))) 2947 return getResult(); 2948 return OpFoldResult(); 2949 } 2950 2951 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { 2952 return llvm::to_vector<4>(getVectorType().getShape()); 2953 } 2954 2955 void TransferReadOp::getEffects( 2956 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 2957 &effects) { 2958 if (getShapedType().isa<MemRefType>()) 2959 effects.emplace_back(MemoryEffects::Read::get(), source(), 2960 SideEffects::DefaultResource::get()); 2961 } 2962 2963 namespace { 2964 /// Fold transfer_reads of a tensor.extract_slice op. E.g.: 2965 /// 2966 /// ``` 2967 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1] 2968 /// : tensor<?x?xf32> to tensor<?x?xf32> 2969 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]} 2970 /// : tensor<?x?xf32>, vector<4x5xf32> 2971 /// ``` 2972 /// is rewritten to: 2973 /// ``` 2974 /// %p0 = arith.addi %a, %e : index 2975 /// %p1 = arith.addi %b, %f : index 2976 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} 2977 /// : tensor<?x?xf32>, vector<4x5xf32> 2978 /// ``` 2979 struct FoldExtractSliceIntoTransferRead 2980 : public OpRewritePattern<TransferReadOp> { 2981 public: 2982 using OpRewritePattern<TransferReadOp>::OpRewritePattern; 2983 2984 LogicalResult matchAndRewrite(TransferReadOp xferOp, 2985 PatternRewriter &rewriter) const override { 2986 // TODO: support 0-d corner case. 2987 if (xferOp.getTransferRank() == 0) 2988 return failure(); 2989 if (xferOp.hasOutOfBoundsDim()) 2990 return failure(); 2991 if (!xferOp.permutation_map().isIdentity()) 2992 return failure(); 2993 if (xferOp.mask()) 2994 return failure(); 2995 auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>(); 2996 if (!extractOp) 2997 return failure(); 2998 if (!extractOp.hasUnitStride()) 2999 return failure(); 3000 3001 // Bail on illegal rank-reduction: we need to check that the rank-reduced 3002 // dims are exactly the leading dims. I.e. 
the following is illegal: 3003 // ``` 3004 // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] : 3005 // tensor<2x1x4xf32> to tensor<2x4xf32> 3006 // %1 = vector.transfer_read %0[0,0], %cst : 3007 // tensor<2x4xf32>, vector<2x4xf32> 3008 // ``` 3009 // 3010 // Cannot fold into: 3011 // ``` 3012 // %0 = vector.transfer_read %t[0,0,0], %cst : 3013 // tensor<2x1x4xf32>, vector<2x4xf32> 3014 // ``` 3015 // For this, check the trailing `vectorRank` dims of the extract_slice 3016 // result tensor match the trailing dims of the inferred result tensor. 3017 int64_t rankReduced = 3018 extractOp.getSourceType().getRank() - extractOp.getType().getRank(); 3019 int64_t vectorRank = xferOp.getVectorType().getRank(); 3020 RankedTensorType inferredDestTensorType = 3021 tensor::ExtractSliceOp::inferResultType( 3022 extractOp.getSourceType(), extractOp.getMixedOffsets(), 3023 extractOp.getMixedSizes(), extractOp.getMixedStrides()); 3024 auto actualDestTensorShape = extractOp.getType().getShape(); 3025 if (rankReduced > 0 && 3026 actualDestTensorShape.take_back(vectorRank) != 3027 inferredDestTensorType.getShape().take_back(vectorRank)) 3028 return failure(); 3029 3030 SmallVector<Value> newIndices; 3031 // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced 3032 // indices first. 3033 for (int64_t i = 0; i < rankReduced; ++i) { 3034 OpFoldResult offset = extractOp.getMixedOffsets()[i]; 3035 newIndices.push_back(getValueOrCreateConstantIndexOp( 3036 rewriter, extractOp.getLoc(), offset)); 3037 } 3038 for (const auto &it : llvm::enumerate(xferOp.indices())) { 3039 OpFoldResult offset = 3040 extractOp.getMixedOffsets()[it.index() + rankReduced]; 3041 newIndices.push_back(rewriter.create<arith::AddIOp>( 3042 xferOp->getLoc(), it.value(), 3043 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), 3044 offset))); 3045 } 3046 SmallVector<bool> inBounds(xferOp.getTransferRank(), true); 3047 rewriter.replaceOpWithNewOp<TransferReadOp>( 3048 xferOp, xferOp.getVectorType(), extractOp.source(), newIndices, 3049 xferOp.padding(), ArrayRef<bool>{inBounds}); 3050 3051 return success(); 3052 } 3053 }; 3054 } // namespace 3055 3056 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3057 MLIRContext *context) { 3058 results.add<FoldExtractSliceIntoTransferRead>(context); 3059 } 3060 3061 //===----------------------------------------------------------------------===// 3062 // TransferWriteOp 3063 //===----------------------------------------------------------------------===// 3064 3065 /// 1. Builder with type inference. 3066 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3067 Value vector, Value dest, ValueRange indices, 3068 AffineMapAttr permutationMapAttr, 3069 /*optional*/ Value mask, 3070 /*optional*/ ArrayAttr inBoundsAttr) { 3071 Type resultType = dest.getType().dyn_cast<RankedTensorType>(); 3072 build(builder, result, resultType, vector, dest, indices, permutationMapAttr, 3073 mask, inBoundsAttr); 3074 } 3075 3076 /// 2. Builder with type inference that sets an empty mask (variant with attrs). 3077 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3078 Value vector, Value dest, ValueRange indices, 3079 AffineMapAttr permutationMapAttr, 3080 /*optional*/ ArrayAttr inBoundsAttr) { 3081 build(builder, result, vector, dest, indices, permutationMapAttr, 3082 /*mask=*/Value(), inBoundsAttr); 3083 } 3084 3085 /// 3. 
Builder with type inference that sets an empty mask (variant without 3086 /// attrs) 3087 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3088 Value vector, Value dest, ValueRange indices, 3089 AffineMap permutationMap, 3090 Optional<ArrayRef<bool>> inBounds) { 3091 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 3092 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 3093 ? builder.getBoolArrayAttr(inBounds.getValue()) 3094 : ArrayAttr(); 3095 build(builder, result, vector, dest, indices, permutationMapAttr, 3096 /*mask=*/Value(), inBoundsAttr); 3097 } 3098 3099 /// 4. Builder with type inference that sets an empty mask and sets permutation 3100 /// map to 'getMinorIdentityMap'. 3101 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3102 Value vector, Value dest, ValueRange indices, 3103 Optional<ArrayRef<bool>> inBounds) { 3104 auto vectorType = vector.getType().cast<VectorType>(); 3105 AffineMap permutationMap = getTransferMinorIdentityMap( 3106 dest.getType().cast<ShapedType>(), vectorType); 3107 build(builder, result, vector, dest, indices, permutationMap, inBounds); 3108 } 3109 3110 ParseResult TransferWriteOp::parse(OpAsmParser &parser, 3111 OperationState &result) { 3112 auto &builder = parser.getBuilder(); 3113 SMLoc typesLoc; 3114 OpAsmParser::OperandType vectorInfo, sourceInfo; 3115 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 3116 SmallVector<Type, 2> types; 3117 OpAsmParser::OperandType maskInfo; 3118 if (parser.parseOperand(vectorInfo) || parser.parseComma() || 3119 parser.parseOperand(sourceInfo) || 3120 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) 3121 return failure(); 3122 ParseResult hasMask = parser.parseOptionalComma(); 3123 if (hasMask.succeeded() && parser.parseOperand(maskInfo)) 3124 return failure(); 3125 if (parser.parseOptionalAttrDict(result.attributes) || 3126 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 3127 return failure(); 3128 if (types.size() != 2) 3129 return parser.emitError(typesLoc, "requires two types"); 3130 auto indexType = builder.getIndexType(); 3131 VectorType vectorType = types[0].dyn_cast<VectorType>(); 3132 if (!vectorType) 3133 return parser.emitError(typesLoc, "requires vector type"); 3134 ShapedType shapedType = types[1].dyn_cast<ShapedType>(); 3135 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 3136 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 3137 auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName(); 3138 auto attr = result.attributes.get(permutationAttrName); 3139 if (!attr) { 3140 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 3141 result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); 3142 } 3143 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || 3144 parser.resolveOperand(sourceInfo, shapedType, result.operands) || 3145 parser.resolveOperands(indexInfo, indexType, result.operands)) 3146 return failure(); 3147 if (hasMask.succeeded()) { 3148 if (shapedType.getElementType().dyn_cast<VectorType>()) 3149 return parser.emitError( 3150 maskInfo.location, "does not support masks with vector element type"); 3151 auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type()); 3152 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 3153 return failure(); 3154 } 3155 result.addAttribute( 3156 TransferWriteOp::getOperandSegmentSizeAttr(), 3157 
builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()), 3158 static_cast<int32_t>(hasMask.succeeded())})); 3159 return failure(shapedType.isa<RankedTensorType>() && 3160 parser.addTypeToList(shapedType, result.types)); 3161 } 3162 3163 void TransferWriteOp::print(OpAsmPrinter &p) { 3164 p << " " << vector() << ", " << source() << "[" << indices() << "]"; 3165 if (mask()) 3166 p << ", " << mask(); 3167 printTransferAttrs(p, *this); 3168 p << " : " << getVectorType() << ", " << getShapedType(); 3169 } 3170 3171 LogicalResult TransferWriteOp::verify() { 3172 // Consistency of elemental types in shape and vector. 3173 ShapedType shapedType = getShapedType(); 3174 VectorType vectorType = getVectorType(); 3175 VectorType maskType = getMaskType(); 3176 auto permutationMap = permutation_map(); 3177 3178 if (llvm::size(indices()) != shapedType.getRank()) 3179 return emitOpError("requires ") << shapedType.getRank() << " indices"; 3180 3181 // We do not allow broadcast dimensions on TransferWriteOps for the moment, 3182 // as the semantics is unclear. This can be revisited later if necessary. 3183 if (hasBroadcastDim()) 3184 return emitOpError("should not have broadcast dimensions"); 3185 3186 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 3187 shapedType, vectorType, maskType, permutationMap, 3188 in_bounds() ? *in_bounds() : ArrayAttr()))) 3189 return failure(); 3190 3191 return verifyPermutationMap(permutationMap, 3192 [&](Twine t) { return emitOpError(t); }); 3193 } 3194 3195 /// Fold: 3196 /// ``` 3197 /// %t1 = ... 3198 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} : 3199 /// tensor<static_sizesxf32>, vector<static_sizesxf32> 3200 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} : 3201 /// vector<static_sizesxf32>, tensor<static_sizesxf32> 3202 /// ``` 3203 /// 3204 /// into: 3205 /// 3206 /// ``` 3207 /// %t0 3208 /// ``` 3209 /// 3210 /// The producer of t1 may or may not be DCE'd depending on whether it is a 3211 /// block argument or has side effects. 3212 static LogicalResult foldReadInitWrite(TransferWriteOp write, 3213 ArrayRef<Attribute>, 3214 SmallVectorImpl<OpFoldResult> &results) { 3215 // TODO: support 0-d corner case. 3216 if (write.getTransferRank() == 0) 3217 return failure(); 3218 auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>(); 3219 // If not operating on tensors, bail. 3220 if (!rankedTensorType) 3221 return failure(); 3222 // If no read, bail. 3223 auto read = write.vector().getDefiningOp<vector::TransferReadOp>(); 3224 if (!read) 3225 return failure(); 3226 // TODO: support 0-d corner case. 3227 if (read.getTransferRank() == 0) 3228 return failure(); 3229 // For now, only accept minor identity. Future: composition is minor identity. 3230 if (!read.permutation_map().isMinorIdentity() || 3231 !write.permutation_map().isMinorIdentity()) 3232 return failure(); 3233 // Bail on mismatching ranks. 3234 if (read.getTransferRank() != write.getTransferRank()) 3235 return failure(); 3236 // Bail on potential out-of-bounds accesses. 3237 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim()) 3238 return failure(); 3239 // Tensor types must be the same. 3240 if (read.source().getType() != rankedTensorType) 3241 return failure(); 3242 // Vector types must be the same. 3243 if (read.getVectorType() != write.getVectorType()) 3244 return failure(); 3245 // Vector and Tensor shapes must match. 
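// e.g. (illustrative) a vector<4x8xf32> transfer over a tensor<4x8xf32>
// qualifies; with a tensor<?x8xf32> source the shapes differ and the fold
// bails at the comparison below.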
3246 if (read.getVectorType().getShape() != rankedTensorType.getShape())
3247 return failure();
3248 // If any index is nonzero.
3249 auto isNotConstantZero = [](Value v) {
3250 auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3251 return !cstOp || cstOp.value() != 0;
3252 };
3253 if (llvm::any_of(read.indices(), isNotConstantZero) ||
3254 llvm::any_of(write.indices(), isNotConstantZero))
3255 return failure();
3256 // Success.
3257 results.push_back(read.source());
3258 return success();
3259 }
3260
3261 static bool checkSameValueWAR(vector::TransferReadOp read,
3262 vector::TransferWriteOp write) {
3263 return read.source() == write.source() && read.indices() == write.indices() &&
3264 read.permutation_map() == write.permutation_map() &&
3265 read.getVectorType() == write.getVectorType() && !read.mask() &&
3266 !write.mask();
3267 }
3268 /// Fold transfer_write write after read:
3269 /// ```
3270 /// %t0 = ...
3271 /// %v = vector.transfer_read %t0[%c0...] :
3272 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3273 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
3274 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3275 /// ```
3276 ///
3277 /// into:
3278 ///
3279 /// ```
3280 /// %t0
3281 /// ```
3282 static LogicalResult foldWAR(TransferWriteOp write,
3283 SmallVectorImpl<OpFoldResult> &results) {
3284 if (!write.source().getType().isa<RankedTensorType>())
3285 return failure();
3286 auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3287 if (!read)
3288 return failure();
3289
3290 if (!checkSameValueWAR(read, write))
3291 return failure();
3292 results.push_back(read.source());
3293 return success();
3294 }
3295
3296 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3297 SmallVectorImpl<OpFoldResult> &results) {
3298 if (succeeded(foldReadInitWrite(*this, operands, results)))
3299 return success();
3300 if (succeeded(foldWAR(*this, results)))
3301 return success();
3302 if (succeeded(foldTransferInBoundsAttribute(*this)))
3303 return success();
3304 return foldMemRefCast(*this);
3305 }
3306
3307 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3308 return llvm::to_vector<4>(getVectorType().getShape());
3309 }
3310
3311 void TransferWriteOp::getEffects(
3312 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3313 &effects) {
3314 if (getShapedType().isa<MemRefType>())
3315 effects.emplace_back(MemoryEffects::Write::get(), source(),
3316 SideEffects::DefaultResource::get());
3317 }
3318
3319 namespace {
3320 /// Remove a dead transfer write from the SSA chain so that it can be
3321 /// eliminated by DCE.
3322 /// ```
3323 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3324 /// : vector<1x4xf32>, tensor<4x4xf32>
3325 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3326 /// : vector<1x4xf32>, tensor<4x4xf32>
3327 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3328 /// : vector<1x4xf32>, tensor<4x4xf32>
3329 /// ```
3330 ///
3331 /// into:
3332 ///
3333 /// ```
3334 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3335 /// : vector<1x4xf32>, tensor<4x4xf32>
3336 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3337 /// : vector<1x4xf32>, tensor<4x4xf32>
3338 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3339 /// : vector<1x4xf32>, tensor<4x4xf32>
3340 /// ```
3341 ///
3342 /// The `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3343 /// any other uses.
3344 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3345 public:
3346 using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
3347 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3348 PatternRewriter &rewriter) const override {
3349 if (!writeOp.getShapedType().isa<RankedTensorType>())
3350 return failure();
3351 vector::TransferWriteOp writeToModify = writeOp;
3352
3353 auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
3354 while (defWrite) {
3355 if (checkSameValueWAW(writeOp, defWrite)) {
3356 writeToModify.sourceMutable().assign(defWrite.source());
3357 return success();
3358 }
3359 if (!isDisjointTransferIndices(
3360 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3361 cast<VectorTransferOpInterface>(writeOp.getOperation())))
3362 break;
3363 // If the previous write op doesn't have any other use we can safely look
3364 // at the previous store to see if it can be removed.
3365 if (!defWrite->hasOneUse())
3366 break;
3367 writeToModify = defWrite;
3368 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
3369 }
3370 return failure();
3371 }
3372 };
3373
3374 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3375 /// could directly write to the insert_slice's destination. E.g.:
3376 ///
3377 /// ```
3378 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3379 /// : vector<4x5xf32>, tensor<4x5xf32>
3380 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3381 /// : tensor<4x5xf32> into tensor<?x?xf32>
3382 /// ```
3383 /// is rewritten to:
3384 /// ```
3385 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3386 /// : vector<4x5xf32>, tensor<?x?xf32>
3387 /// ```
3388 struct FoldInsertSliceIntoTransferWrite
3389 : public OpRewritePattern<tensor::InsertSliceOp> {
3390 public:
3391 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3392
3393 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3394 PatternRewriter &rewriter) const override {
3395 if (!insertOp.hasUnitStride())
3396 return failure();
3397
3398 auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
3399 if (!xferOp)
3400 return failure();
3401 // TODO: support 0-d corner case.
3402 if (xferOp.getTransferRank() == 0)
3403 return failure();
3404
3405 if (xferOp.hasOutOfBoundsDim())
3406 return failure();
3407 if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3408 return failure();
3409 if (xferOp.mask())
3410 return failure();
3411 // Fold only if the TransferWriteOp completely overwrites the `source` with
3412 // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3413 // content is the data of the vector.
3414 if (!llvm::equal(xferOp.getVectorType().getShape(),
3415 xferOp.getShapedType().getShape()))
3416 return failure();
3417 if (!xferOp.permutation_map().isIdentity())
3418 return failure();
3419
3420 // Bail on illegal rank-reduction: we need to check that the rank-reduced
3421 // dims are exactly the leading dims. I.e. the following is illegal:
3422 // ```
3423 // %0 = vector.transfer_write %v, %t[0,0] :
3424 // vector<2x4xf32>, tensor<2x4xf32>
3425 // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
3426 // tensor<2x4xf32> into tensor<2x1x4xf32>
3427 // ```
3428 //
3429 // Cannot fold into:
3430 // ```
3431 // %0 = vector.transfer_write %v, %t[0,0,0] :
3432 // vector<2x4xf32>, tensor<2x1x4xf32>
3433 // ```
3434 // For this, check the trailing `vectorRank` dims of the insert_slice result
3435 // tensor match the trailing dims of the inferred result tensor.
3436 int64_t rankReduced =
3437 insertOp.getType().getRank() - insertOp.getSourceType().getRank();
3438 int64_t vectorRank = xferOp.getVectorType().getRank();
3439 RankedTensorType inferredSourceTensorType =
3440 tensor::ExtractSliceOp::inferResultType(
3441 insertOp.getType(), insertOp.getMixedOffsets(),
3442 insertOp.getMixedSizes(), insertOp.getMixedStrides());
3443 auto actualSourceTensorShape = insertOp.getSourceType().getShape();
3444 if (rankReduced > 0 &&
3445 actualSourceTensorShape.take_back(vectorRank) !=
3446 inferredSourceTensorType.getShape().take_back(vectorRank))
3447 return failure();
3448
3449 SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
3450 rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
3451 SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3452 rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(),
3453 insertOp.dest(), indices,
3454 ArrayRef<bool>{inBounds});
3455 return success();
3456 }
3457 };
3458 } // namespace
3459
3460 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
3461 MLIRContext *context) {
3462 results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context);
3463 }
3464
3465 //===----------------------------------------------------------------------===//
3466 // LoadOp
3467 //===----------------------------------------------------------------------===//
3468
3469 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
3470 MemRefType memRefTy) {
3471 if (!isLastMemrefDimUnitStride(memRefTy))
3472 return op->emitOpError("most minor memref dim must have unit stride");
3473 return success();
3474 }
3475
3476 LogicalResult vector::LoadOp::verify() {
3477 VectorType resVecTy = getVectorType();
3478 MemRefType memRefTy = getMemRefType();
3479
3480 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
3481 return failure();
3482
3483 // Checks for vector memrefs.
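// e.g. (illustrative) loading from a memref whose element type is itself a
// vector is allowed when the types line up:
//   %v = vector.load %base[%i] : memref<100xvector<8xf32>>, vector<8xf32>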
3484 Type memElemTy = memRefTy.getElementType(); 3485 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3486 if (memVecTy != resVecTy) 3487 return emitOpError("base memref and result vector types should match"); 3488 memElemTy = memVecTy.getElementType(); 3489 } 3490 3491 if (resVecTy.getElementType() != memElemTy) 3492 return emitOpError("base and result element types should match"); 3493 if (llvm::size(indices()) != memRefTy.getRank()) 3494 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 3495 return success(); 3496 } 3497 3498 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) { 3499 if (succeeded(foldMemRefCast(*this))) 3500 return getResult(); 3501 return OpFoldResult(); 3502 } 3503 3504 //===----------------------------------------------------------------------===// 3505 // StoreOp 3506 //===----------------------------------------------------------------------===// 3507 3508 LogicalResult vector::StoreOp::verify() { 3509 VectorType valueVecTy = getVectorType(); 3510 MemRefType memRefTy = getMemRefType(); 3511 3512 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) 3513 return failure(); 3514 3515 // Checks for vector memrefs. 3516 Type memElemTy = memRefTy.getElementType(); 3517 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3518 if (memVecTy != valueVecTy) 3519 return emitOpError( 3520 "base memref and valueToStore vector types should match"); 3521 memElemTy = memVecTy.getElementType(); 3522 } 3523 3524 if (valueVecTy.getElementType() != memElemTy) 3525 return emitOpError("base and valueToStore element type should match"); 3526 if (llvm::size(indices()) != memRefTy.getRank()) 3527 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 3528 return success(); 3529 } 3530 3531 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands, 3532 SmallVectorImpl<OpFoldResult> &results) { 3533 return foldMemRefCast(*this); 3534 } 3535 3536 //===----------------------------------------------------------------------===// 3537 // MaskedLoadOp 3538 //===----------------------------------------------------------------------===// 3539 3540 LogicalResult MaskedLoadOp::verify() { 3541 VectorType maskVType = getMaskVectorType(); 3542 VectorType passVType = getPassThruVectorType(); 3543 VectorType resVType = getVectorType(); 3544 MemRefType memType = getMemRefType(); 3545 3546 if (resVType.getElementType() != memType.getElementType()) 3547 return emitOpError("base and result element type should match"); 3548 if (llvm::size(indices()) != memType.getRank()) 3549 return emitOpError("requires ") << memType.getRank() << " indices"; 3550 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3551 return emitOpError("expected result dim to match mask dim"); 3552 if (resVType != passVType) 3553 return emitOpError("expected pass_thru of same type as result type"); 3554 return success(); 3555 } 3556 3557 namespace { 3558 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { 3559 public: 3560 using OpRewritePattern<MaskedLoadOp>::OpRewritePattern; 3561 LogicalResult matchAndRewrite(MaskedLoadOp load, 3562 PatternRewriter &rewriter) const override { 3563 switch (get1DMaskFormat(load.mask())) { 3564 case MaskFormat::AllTrue: 3565 rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(), 3566 load.base(), load.indices()); 3567 return success(); 3568 case MaskFormat::AllFalse: 3569 rewriter.replaceOp(load, load.pass_thru()); 3570 return success(); 3571 case MaskFormat::Unknown: 3572 return failure(); 3573 } 3574 llvm_unreachable("Unexpected 
1DMaskFormat on MaskedLoad"); 3575 } 3576 }; 3577 } // namespace 3578 3579 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3580 MLIRContext *context) { 3581 results.add<MaskedLoadFolder>(context); 3582 } 3583 3584 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) { 3585 if (succeeded(foldMemRefCast(*this))) 3586 return getResult(); 3587 return OpFoldResult(); 3588 } 3589 3590 //===----------------------------------------------------------------------===// 3591 // MaskedStoreOp 3592 //===----------------------------------------------------------------------===// 3593 3594 LogicalResult MaskedStoreOp::verify() { 3595 VectorType maskVType = getMaskVectorType(); 3596 VectorType valueVType = getVectorType(); 3597 MemRefType memType = getMemRefType(); 3598 3599 if (valueVType.getElementType() != memType.getElementType()) 3600 return emitOpError("base and valueToStore element type should match"); 3601 if (llvm::size(indices()) != memType.getRank()) 3602 return emitOpError("requires ") << memType.getRank() << " indices"; 3603 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3604 return emitOpError("expected valueToStore dim to match mask dim"); 3605 return success(); 3606 } 3607 3608 namespace { 3609 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { 3610 public: 3611 using OpRewritePattern<MaskedStoreOp>::OpRewritePattern; 3612 LogicalResult matchAndRewrite(MaskedStoreOp store, 3613 PatternRewriter &rewriter) const override { 3614 switch (get1DMaskFormat(store.mask())) { 3615 case MaskFormat::AllTrue: 3616 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3617 store, store.valueToStore(), store.base(), store.indices()); 3618 return success(); 3619 case MaskFormat::AllFalse: 3620 rewriter.eraseOp(store); 3621 return success(); 3622 case MaskFormat::Unknown: 3623 return failure(); 3624 } 3625 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore"); 3626 } 3627 }; 3628 } // namespace 3629 3630 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3631 MLIRContext *context) { 3632 results.add<MaskedStoreFolder>(context); 3633 } 3634 3635 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands, 3636 SmallVectorImpl<OpFoldResult> &results) { 3637 return foldMemRefCast(*this); 3638 } 3639 3640 //===----------------------------------------------------------------------===// 3641 // GatherOp 3642 //===----------------------------------------------------------------------===// 3643 3644 LogicalResult GatherOp::verify() { 3645 VectorType indVType = getIndexVectorType(); 3646 VectorType maskVType = getMaskVectorType(); 3647 VectorType resVType = getVectorType(); 3648 MemRefType memType = getMemRefType(); 3649 3650 if (resVType.getElementType() != memType.getElementType()) 3651 return emitOpError("base and result element type should match"); 3652 if (llvm::size(indices()) != memType.getRank()) 3653 return emitOpError("requires ") << memType.getRank() << " indices"; 3654 if (resVType.getDimSize(0) != indVType.getDimSize(0)) 3655 return emitOpError("expected result dim to match indices dim"); 3656 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3657 return emitOpError("expected result dim to match mask dim"); 3658 if (resVType != getPassThruVectorType()) 3659 return emitOpError("expected pass_thru of same type as result type"); 3660 return success(); 3661 } 3662 3663 namespace { 3664 class GatherFolder final : public OpRewritePattern<GatherOp> { 3665 public: 3666 using OpRewritePattern<GatherOp>::OpRewritePattern; 
3667 LogicalResult matchAndRewrite(GatherOp gather, 3668 PatternRewriter &rewriter) const override { 3669 switch (get1DMaskFormat(gather.mask())) { 3670 case MaskFormat::AllTrue: 3671 return failure(); // no unmasked equivalent 3672 case MaskFormat::AllFalse: 3673 rewriter.replaceOp(gather, gather.pass_thru()); 3674 return success(); 3675 case MaskFormat::Unknown: 3676 return failure(); 3677 } 3678 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder"); 3679 } 3680 }; 3681 } // namespace 3682 3683 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, 3684 MLIRContext *context) { 3685 results.add<GatherFolder>(context); 3686 } 3687 3688 //===----------------------------------------------------------------------===// 3689 // ScatterOp 3690 //===----------------------------------------------------------------------===// 3691 3692 LogicalResult ScatterOp::verify() { 3693 VectorType indVType = getIndexVectorType(); 3694 VectorType maskVType = getMaskVectorType(); 3695 VectorType valueVType = getVectorType(); 3696 MemRefType memType = getMemRefType(); 3697 3698 if (valueVType.getElementType() != memType.getElementType()) 3699 return emitOpError("base and valueToStore element type should match"); 3700 if (llvm::size(indices()) != memType.getRank()) 3701 return emitOpError("requires ") << memType.getRank() << " indices"; 3702 if (valueVType.getDimSize(0) != indVType.getDimSize(0)) 3703 return emitOpError("expected valueToStore dim to match indices dim"); 3704 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3705 return emitOpError("expected valueToStore dim to match mask dim"); 3706 return success(); 3707 } 3708 3709 namespace { 3710 class ScatterFolder final : public OpRewritePattern<ScatterOp> { 3711 public: 3712 using OpRewritePattern<ScatterOp>::OpRewritePattern; 3713 LogicalResult matchAndRewrite(ScatterOp scatter, 3714 PatternRewriter &rewriter) const override { 3715 switch (get1DMaskFormat(scatter.mask())) { 3716 case MaskFormat::AllTrue: 3717 return failure(); // no unmasked equivalent 3718 case MaskFormat::AllFalse: 3719 rewriter.eraseOp(scatter); 3720 return success(); 3721 case MaskFormat::Unknown: 3722 return failure(); 3723 } 3724 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder"); 3725 } 3726 }; 3727 } // namespace 3728 3729 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results, 3730 MLIRContext *context) { 3731 results.add<ScatterFolder>(context); 3732 } 3733 3734 //===----------------------------------------------------------------------===// 3735 // ExpandLoadOp 3736 //===----------------------------------------------------------------------===// 3737 3738 LogicalResult ExpandLoadOp::verify() { 3739 VectorType maskVType = getMaskVectorType(); 3740 VectorType passVType = getPassThruVectorType(); 3741 VectorType resVType = getVectorType(); 3742 MemRefType memType = getMemRefType(); 3743 3744 if (resVType.getElementType() != memType.getElementType()) 3745 return emitOpError("base and result element type should match"); 3746 if (llvm::size(indices()) != memType.getRank()) 3747 return emitOpError("requires ") << memType.getRank() << " indices"; 3748 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3749 return emitOpError("expected result dim to match mask dim"); 3750 if (resVType != passVType) 3751 return emitOpError("expected pass_thru of same type as result type"); 3752 return success(); 3753 } 3754 3755 namespace { 3756 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { 3757 public: 3758 using 
OpRewritePattern<ExpandLoadOp>::OpRewritePattern; 3759 LogicalResult matchAndRewrite(ExpandLoadOp expand, 3760 PatternRewriter &rewriter) const override { 3761 switch (get1DMaskFormat(expand.mask())) { 3762 case MaskFormat::AllTrue: 3763 rewriter.replaceOpWithNewOp<vector::LoadOp>( 3764 expand, expand.getType(), expand.base(), expand.indices()); 3765 return success(); 3766 case MaskFormat::AllFalse: 3767 rewriter.replaceOp(expand, expand.pass_thru()); 3768 return success(); 3769 case MaskFormat::Unknown: 3770 return failure(); 3771 } 3772 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder"); 3773 } 3774 }; 3775 } // namespace 3776 3777 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3778 MLIRContext *context) { 3779 results.add<ExpandLoadFolder>(context); 3780 } 3781 3782 //===----------------------------------------------------------------------===// 3783 // CompressStoreOp 3784 //===----------------------------------------------------------------------===// 3785 3786 LogicalResult CompressStoreOp::verify() { 3787 VectorType maskVType = getMaskVectorType(); 3788 VectorType valueVType = getVectorType(); 3789 MemRefType memType = getMemRefType(); 3790 3791 if (valueVType.getElementType() != memType.getElementType()) 3792 return emitOpError("base and valueToStore element type should match"); 3793 if (llvm::size(indices()) != memType.getRank()) 3794 return emitOpError("requires ") << memType.getRank() << " indices"; 3795 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3796 return emitOpError("expected valueToStore dim to match mask dim"); 3797 return success(); 3798 } 3799 3800 namespace { 3801 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { 3802 public: 3803 using OpRewritePattern<CompressStoreOp>::OpRewritePattern; 3804 LogicalResult matchAndRewrite(CompressStoreOp compress, 3805 PatternRewriter &rewriter) const override { 3806 switch (get1DMaskFormat(compress.mask())) { 3807 case MaskFormat::AllTrue: 3808 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3809 compress, compress.valueToStore(), compress.base(), 3810 compress.indices()); 3811 return success(); 3812 case MaskFormat::AllFalse: 3813 rewriter.eraseOp(compress); 3814 return success(); 3815 case MaskFormat::Unknown: 3816 return failure(); 3817 } 3818 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder"); 3819 } 3820 }; 3821 } // namespace 3822 3823 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3824 MLIRContext *context) { 3825 results.add<CompressStoreFolder>(context); 3826 } 3827 3828 //===----------------------------------------------------------------------===// 3829 // ShapeCastOp 3830 //===----------------------------------------------------------------------===// 3831 3832 /// Returns true if each element of 'a' is equal to the product of a contiguous 3833 /// sequence of the elements of 'b'. Returns false otherwise. 3834 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 3835 unsigned rankA = a.size(); 3836 unsigned rankB = b.size(); 3837 assert(rankA < rankB); 3838 3839 unsigned i = 0; 3840 unsigned j = 0; 3841 while (i < rankA && j < rankB) { 3842 int64_t dimA = a[i]; 3843 int64_t dimB = 1; 3844 while (dimB < dimA && j < rankB) 3845 dimB *= b[j++]; 3846 if (dimA != dimB) 3847 break; 3848 ++i; 3849 3850 // Handle the case when trailing dimensions are of size 1. 3851 // Include them into the contiguous sequence. 
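// e.g. (illustrative) for a = [4, 2] and b = [4, 2, 1, 1]: once 4 and 2 are
// matched, the trailing 1s in b are absorbed and the cast is accepted.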
3852 auto isOne = [](int64_t v) { return v == 1; }; 3853 if (i < rankA && llvm::all_of(a.slice(i), isOne)) 3854 i = rankA; 3855 if (j < rankB && llvm::all_of(b.slice(j), isOne)) 3856 j = rankB; 3857 } 3858 3859 return i == rankA && j == rankB; 3860 } 3861 3862 static LogicalResult verifyVectorShapeCast(Operation *op, 3863 VectorType sourceVectorType, 3864 VectorType resultVectorType) { 3865 // Check that element type is the same. 3866 if (sourceVectorType.getElementType() != resultVectorType.getElementType()) 3867 return op->emitOpError("source/result vectors must have same element type"); 3868 auto sourceShape = sourceVectorType.getShape(); 3869 auto resultShape = resultVectorType.getShape(); 3870 3871 // Check that product of source dim sizes matches product of result dim sizes. 3872 int64_t sourceDimProduct = std::accumulate( 3873 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{}); 3874 int64_t resultDimProduct = std::accumulate( 3875 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{}); 3876 if (sourceDimProduct != resultDimProduct) 3877 return op->emitOpError("source/result number of elements must match"); 3878 3879 // Check that expanding/contracting rank cases. 3880 unsigned sourceRank = sourceVectorType.getRank(); 3881 unsigned resultRank = resultVectorType.getRank(); 3882 if (sourceRank < resultRank) { 3883 if (!isValidShapeCast(sourceShape, resultShape)) 3884 return op->emitOpError("invalid shape cast"); 3885 } else if (sourceRank > resultRank) { 3886 if (!isValidShapeCast(resultShape, sourceShape)) 3887 return op->emitOpError("invalid shape cast"); 3888 } 3889 return success(); 3890 } 3891 3892 LogicalResult ShapeCastOp::verify() { 3893 auto sourceVectorType = source().getType().dyn_cast_or_null<VectorType>(); 3894 auto resultVectorType = result().getType().dyn_cast_or_null<VectorType>(); 3895 3896 // Check if source/result are of vector type. 3897 if (sourceVectorType && resultVectorType) 3898 return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType); 3899 3900 return success(); 3901 } 3902 3903 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) { 3904 // Nop shape cast. 3905 if (source().getType() == result().getType()) 3906 return source(); 3907 3908 // Canceling shape casts. 3909 if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) { 3910 if (result().getType() == otherOp.source().getType()) 3911 return otherOp.source(); 3912 3913 // Only allows valid transitive folding. 3914 VectorType srcType = otherOp.source().getType().cast<VectorType>(); 3915 VectorType resultType = getResult().getType().cast<VectorType>(); 3916 if (srcType.getRank() < resultType.getRank()) { 3917 if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) 3918 return {}; 3919 } else if (srcType.getRank() > resultType.getRank()) { 3920 if (!isValidShapeCast(resultType.getShape(), srcType.getShape())) 3921 return {}; 3922 } else { 3923 return {}; 3924 } 3925 3926 setOperand(otherOp.source()); 3927 return getResult(); 3928 } 3929 return {}; 3930 } 3931 3932 namespace { 3933 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. 
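//
// For example (illustrative):
//   %c = arith.constant dense<1.0> : vector<4x2xf32>
//   %0 = vector.shape_cast %c : vector<4x2xf32> to vector<8xf32>
// folds to a single constant:
//   %0 = arith.constant dense<1.0> : vector<8xf32>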
3934 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { 3935 public: 3936 using OpRewritePattern<ShapeCastOp>::OpRewritePattern; 3937 3938 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, 3939 PatternRewriter &rewriter) const override { 3940 auto constantOp = shapeCastOp.source().getDefiningOp<arith::ConstantOp>(); 3941 if (!constantOp) 3942 return failure(); 3943 // Only handle splat for now. 3944 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 3945 if (!dense) 3946 return failure(); 3947 auto newAttr = 3948 DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(), 3949 dense.getSplatValue<Attribute>()); 3950 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr); 3951 return success(); 3952 } 3953 }; 3954 3955 } // namespace 3956 3957 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 3958 MLIRContext *context) { 3959 // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp. 3960 results.add<ShapeCastConstantFolder>(context); 3961 } 3962 3963 //===----------------------------------------------------------------------===// 3964 // VectorBitCastOp 3965 //===----------------------------------------------------------------------===// 3966 3967 LogicalResult BitCastOp::verify() { 3968 auto sourceVectorType = getSourceVectorType(); 3969 auto resultVectorType = getResultVectorType(); 3970 3971 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { 3972 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) 3973 return emitOpError("dimension size mismatch at: ") << i; 3974 } 3975 3976 DataLayout dataLayout = DataLayout::closest(*this); 3977 auto sourceElementBits = 3978 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); 3979 auto resultElementBits = 3980 dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); 3981 3982 if (sourceVectorType.getRank() == 0) { 3983 if (sourceElementBits != resultElementBits) 3984 return emitOpError("source/result bitwidth of the 0-D vector element " 3985 "types must be equal"); 3986 } else if (sourceElementBits * sourceVectorType.getShape().back() != 3987 resultElementBits * resultVectorType.getShape().back()) { 3988 return emitOpError( 3989 "source/result bitwidth of the minor 1-D vectors must be equal"); 3990 } 3991 3992 return success(); 3993 } 3994 3995 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) { 3996 // Nop cast. 3997 if (source().getType() == result().getType()) 3998 return source(); 3999 4000 // Canceling bitcasts. 4001 if (auto otherOp = source().getDefiningOp<BitCastOp>()) 4002 if (result().getType() == otherOp.source().getType()) 4003 return otherOp.source(); 4004 4005 Attribute sourceConstant = operands.front(); 4006 if (!sourceConstant) 4007 return {}; 4008 4009 Type srcElemType = getSourceVectorType().getElementType(); 4010 Type dstElemType = getResultVectorType().getElementType(); 4011 4012 if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) { 4013 if (floatPack.isSplat()) { 4014 auto splat = floatPack.getSplatValue<FloatAttr>(); 4015 4016 // Casting fp16 into fp32. 4017 if (srcElemType.isF16() && dstElemType.isF32()) { 4018 uint32_t bits = static_cast<uint32_t>( 4019 splat.getValue().bitcastToAPInt().getZExtValue()); 4020 // Duplicate the 16-bit pattern. 
4021 bits = (bits << 16) | (bits & 0xffff); 4022 APInt intBits(32, bits); 4023 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits); 4024 return DenseElementsAttr::get(getResultVectorType(), floatBits); 4025 } 4026 } 4027 } 4028 4029 return {}; 4030 } 4031 4032 //===----------------------------------------------------------------------===// 4033 // TypeCastOp 4034 //===----------------------------------------------------------------------===// 4035 4036 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { 4037 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>(); 4038 SmallVector<int64_t, 8> res(memRefType.getShape().begin(), 4039 memRefType.getShape().end()); 4040 if (vectorType) 4041 res.append(vectorType.getShape().begin(), vectorType.getShape().end()); 4042 return res; 4043 } 4044 4045 /// Build the canonical memRefType with a single vector. 4046 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>. 4047 void TypeCastOp::build(OpBuilder &builder, OperationState &result, 4048 Value source) { 4049 result.addOperands(source); 4050 MemRefType memRefType = source.getType().cast<MemRefType>(); 4051 VectorType vectorType = 4052 VectorType::get(extractShape(memRefType), 4053 getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); 4054 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), 4055 memRefType.getMemorySpace())); 4056 } 4057 4058 LogicalResult TypeCastOp::verify() { 4059 MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType()); 4060 if (!canonicalType.getLayout().isIdentity()) 4061 return emitOpError("expects operand to be a memref with identity layout"); 4062 if (!getResultMemRefType().getLayout().isIdentity()) 4063 return emitOpError("expects result to be a memref with identity layout"); 4064 if (getResultMemRefType().getMemorySpace() != 4065 getMemRefType().getMemorySpace()) 4066 return emitOpError("expects result in same memory space"); 4067 4068 auto sourceType = getMemRefType(); 4069 auto resultType = getResultMemRefType(); 4070 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != 4071 getElementTypeOrSelf(getElementTypeOrSelf(resultType))) 4072 return emitOpError( 4073 "expects result and operand with same underlying scalar type: ") 4074 << resultType; 4075 if (extractShape(sourceType) != extractShape(resultType)) 4076 return emitOpError( 4077 "expects concatenated result and operand shapes to be equal: ") 4078 << resultType; 4079 return success(); 4080 } 4081 4082 //===----------------------------------------------------------------------===// 4083 // TransposeOp 4084 //===----------------------------------------------------------------------===// 4085 4086 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, 4087 Value vector, ArrayRef<int64_t> transp) { 4088 VectorType vt = vector.getType().cast<VectorType>(); 4089 SmallVector<int64_t, 4> transposedShape(vt.getRank()); 4090 for (unsigned i = 0; i < transp.size(); ++i) 4091 transposedShape[i] = vt.getShape()[transp[i]]; 4092 4093 result.addOperands(vector); 4094 result.addTypes(VectorType::get(transposedShape, vt.getElementType())); 4095 result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp)); 4096 } 4097 4098 // Eliminates transpose operations, which produce values identical to their 4099 // input values. This happens when the dimensions of the input vector remain in 4100 // their original order after the transpose operation. 
4101 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) { 4102 SmallVector<int64_t, 4> transp; 4103 getTransp(transp); 4104 4105 // Check if the permutation of the dimensions contains sequential values: 4106 // {0, 1, 2, ...}. 4107 for (int64_t i = 0, e = transp.size(); i < e; i++) { 4108 if (transp[i] != i) 4109 return {}; 4110 } 4111 4112 return vector(); 4113 } 4114 4115 LogicalResult vector::TransposeOp::verify() { 4116 VectorType vectorType = getVectorType(); 4117 VectorType resultType = getResultType(); 4118 int64_t rank = resultType.getRank(); 4119 if (vectorType.getRank() != rank) 4120 return emitOpError("vector result rank mismatch: ") << rank; 4121 // Verify transposition array. 4122 auto transpAttr = transp().getValue(); 4123 int64_t size = transpAttr.size(); 4124 if (rank != size) 4125 return emitOpError("transposition length mismatch: ") << size; 4126 SmallVector<bool, 8> seen(rank, false); 4127 for (const auto &ta : llvm::enumerate(transpAttr)) { 4128 int64_t i = ta.value().cast<IntegerAttr>().getInt(); 4129 if (i < 0 || i >= rank) 4130 return emitOpError("transposition index out of range: ") << i; 4131 if (seen[i]) 4132 return emitOpError("duplicate position index: ") << i; 4133 seen[i] = true; 4134 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) 4135 return emitOpError("dimension size mismatch at: ") << i; 4136 } 4137 return success(); 4138 } 4139 4140 namespace { 4141 4142 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. 4143 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { 4144 public: 4145 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 4146 4147 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 4148 PatternRewriter &rewriter) const override { 4149 // Wrapper around vector::TransposeOp::getTransp() for cleaner code. 4150 auto getPermutation = [](vector::TransposeOp transpose) { 4151 SmallVector<int64_t, 4> permutation; 4152 transpose.getTransp(permutation); 4153 return permutation; 4154 }; 4155 4156 // Composes two permutations: result[i] = permutation1[permutation2[i]]. 4157 auto composePermutations = [](ArrayRef<int64_t> permutation1, 4158 ArrayRef<int64_t> permutation2) { 4159 SmallVector<int64_t, 4> result; 4160 for (auto index : permutation2) 4161 result.push_back(permutation1[index]); 4162 return result; 4163 }; 4164 4165 // Return if the input of 'transposeOp' is not defined by another transpose. 4166 vector::TransposeOp parentTransposeOp = 4167 transposeOp.vector().getDefiningOp<vector::TransposeOp>(); 4168 if (!parentTransposeOp) 4169 return failure(); 4170 4171 SmallVector<int64_t, 4> permutation = composePermutations( 4172 getPermutation(parentTransposeOp), getPermutation(transposeOp)); 4173 // Replace 'transposeOp' with a new transpose operation. 
4174 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
4175 transposeOp, transposeOp.getResult().getType(),
4176 parentTransposeOp.vector(),
4177 vector::getVectorSubscriptAttr(rewriter, permutation));
4178 return success();
4179 }
4180 };
4181
4182 } // namespace
4183
4184 void vector::TransposeOp::getCanonicalizationPatterns(
4185 RewritePatternSet &results, MLIRContext *context) {
4186 results.add<TransposeFolder>(context);
4187 }
4188
4189 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
4190 populateFromInt64AttrArray(transp(), results);
4191 }
4192
4193 //===----------------------------------------------------------------------===//
4194 // ConstantMaskOp
4195 //===----------------------------------------------------------------------===//
4196
4197 LogicalResult ConstantMaskOp::verify() {
4198 auto resultType = getResult().getType().cast<VectorType>();
4199 // Check the corner case of 0-D vectors first.
4200 if (resultType.getRank() == 0) {
4201 if (mask_dim_sizes().size() != 1)
4202 return emitError("array attr must have length 1 for 0-D vectors");
4203 auto dim = mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
4204 if (dim != 0 && dim != 1)
4205 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
4206 return success();
4207 }
4208
4209 // Verify that the array attr size matches the rank of the vector result.
4210 if (static_cast<int64_t>(mask_dim_sizes().size()) != resultType.getRank())
4211 return emitOpError(
4212 "must specify array attr of size equal vector result rank");
4213 // Verify that each array attr element is in bounds of the corresponding
4214 // vector result dimension size.
4215 auto resultShape = resultType.getShape();
4216 SmallVector<int64_t, 4> maskDimSizes;
4217 for (const auto &it : llvm::enumerate(mask_dim_sizes())) {
4218 int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
4219 if (attrValue < 0 || attrValue > resultShape[it.index()])
4220 return emitOpError(
4221 "array attr of size out of bounds of vector result dimension size");
4222 maskDimSizes.push_back(attrValue);
4223 }
4224 // Verify that if one mask dim size is zero, they all are zero (because the
4225 // mask region is a conjunction of each mask dimension interval).
4226 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
4227 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
4228 if (anyZeros && !allZeros)
4229 return emitOpError("expected all mask dim sizes to be zeros, "
4230 "as a result of conjunction with zero mask dim");
4231 return success();
4232 }
4233
4234 //===----------------------------------------------------------------------===//
4235 // CreateMaskOp
4236 //===----------------------------------------------------------------------===//
4237
4238 LogicalResult CreateMaskOp::verify() {
4239 auto vectorType = getResult().getType().cast<VectorType>();
4240 // Verify that an operand was specified for each result vector dimension.
4241 if (vectorType.getRank() == 0) {
4242 if (getNumOperands() != 1)
4243 return emitOpError(
4244 "must specify exactly one operand for 0-D create_mask");
4245 } else if (getNumOperands() !=
4246 getResult().getType().cast<VectorType>().getRank()) {
4247 return emitOpError(
4248 "must specify an operand for each result vector dimension");
4249 }
4250 return success();
4251 }
4252
4253 namespace {
4254
4255 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
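//
// For example (illustrative), given %c2 = arith.constant 2 : index:
//   %m = vector.create_mask %c2 : vector<4xi1>
// is rewritten to:
//   %m = vector.constant_mask [2] : vector<4xi1>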
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto it : llvm::zip(createMaskOp.operands(),
                             createMaskOp.getType().getShape())) {
      auto *defOp = std::get<0>(it).getDefiningOp();
      int64_t maxDimSize = std::get<1>(it);
      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
      dimSize = std::min(dimSize, maxDimSize);
      // If one of the dim sizes is zero, set all dims to zero.
      if (dimSize <= 0) {
        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
        break;
      }
      maskDimSizes.push_back(dimSize);
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};

} // namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = reduction_dim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
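  // As an illustrative sketch (hypothetical shapes): for a source of type
  // vector<4x8x16xf32> with reduction_dim = 1, the expected initial value
  // shape is [4, 16], i.e. the initial value must have type vector<4x16xf32>.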
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int64_t i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/initial value shapes");
  }

  return success();
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = operands.front();
  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"