//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat get1DMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.mask_dim_sizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}

// Helper for verifying combining kinds in contractions and reductions.
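// As a rough illustration (pseudo-calls, not actual code in this file):
//   isSupportedCombiningKind(CombiningKind::ADD, f32)  -> true
//   isSupportedCombiningKind(CombiningKind::AND, f32)  -> false
//   isSupportedCombiningKind(CombiningKind::MINF, i32) -> false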
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return elementType.isa<FloatType>();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      shapedType.getElementType().dyn_cast<VectorType>();
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
         defWrite.indices() == read.indices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.permutation_map() == read.permutation_map();
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.indices() == write.indices() &&
         priorWrite.mask() == write.mask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.permutation_map() == write.permutation_map();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
    auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
    auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
    // If any of the indices are dynamic we cannot prove anything.
    if (!indexA || !indexB)
      continue;

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different, we know we are accessing disjoint slices.
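      // For example (illustrative): two transfers of vector<4xf32> on a
      // memref<8x16xf32> at indices [%c0, %i] and [%c1, %i] have different
      // constant leading indices (0 vs 1), so the accesses are disjoint.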
      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
          indexB.getValue().cast<IntegerAttr>().getInt())
        return true;
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t distance =
          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
                   indexB.getValue().cast<IntegerAttr>().getInt());
      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
        return true;
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB) {
  if (transferA.source() != transferB.source())
    return false;
  return isDisjointTransferIndices(transferA, transferB);
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MINF,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,
    CombiningKind::MAXF,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};

void CombiningKindAttr::print(AsmPrinter &printer) const {
  printer << "<";
  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
    return bitEnumContains(this->getKind(), kind);
  });
  llvm::interleaveComma(kinds, printer,
                        [&](auto kind) { printer << stringifyEnum(kind); });
  printer << ">";
}

Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  StringRef elemName;
  if (failed(parser.parseKeyword(&elemName)))
    return {};

  auto kind = symbolizeCombiningKind(elemName);
  if (!kind) {
    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
        << elemName;
    return {};
  }

  if (failed(parser.parseGreater()))
    return {};

  return CombiningKindAttr::get(kind.getValue(), parser.getContext());
}

Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  if (attrKind == "kind")
    return CombiningKindAttr::parse(parser, {});

  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
  return {};
}

void
VectorDialect::printAttribute(Attribute attr,
                              DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
    os << "kind";
    ck.print(os);
    return;
  }
  llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  result.addOperands(source);
  auto sourceVectorType = source.getType().cast<VectorType>();
  auto targetType = MultiDimReductionOp::inferDestType(
      sourceVectorType.getShape(), reductionMask,
      sourceVectorType.getElementType());
  result.addTypes(targetType);

  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  result.addAttribute(getReductionDimsAttrName(),
                      builder.getI64ArrayAttr(reductionDims));
  result.addAttribute(getKindAttrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}

LogicalResult MultiDimReductionOp::verify() {
  auto reductionMask = getReductionMask();
  auto targetType = MultiDimReductionOp::inferDestType(
      getSourceVectorType().getShape(), reductionMask,
      getSourceVectorType().getElementType());
  // TODO: update to support 0-d vectors when available.
  if (targetType != getDestType())
    return emitError("invalid output vector type: ")
           << getDestType() << " (expected: " << targetType << ")";
  return success();
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return source();
  return {};
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

LogicalResult ReductionOp::verify() {
  // Verify for 1-D vector.
  int64_t rank = getVectorType().getRank();
  if (rank != 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  StringRef strKind = kind();
  auto maybeKind = symbolizeCombiningKind(strKind);
  if (!maybeKind)
    return emitOpError("unknown reduction kind: ") << strKind;

  Type eltType = dest().getType();
  if (!isSupportedCombiningKind(*maybeKind, eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << strKind << "'";

  // Verify optional accumulator.
  if (!acc().empty()) {
    if (strKind != "add" && strKind != "mul")
      return emitOpError("no accumulator for reduction kind: ") << strKind;
    if (!eltType.isa<FloatType>())
      return emitOpError("no accumulator for type: ") << eltType;
  }

  return success();
}

ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (!operandsInfo.empty() &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.empty() || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

void ReductionOp::print(OpAsmPrinter &p) {
  p << " \"" << kind() << "\", " << vector();
  if (!acc().empty())
    p << ", " << acc();
  p << " : " << vector().getType() << " into " << dest().getType();
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  Type scalarType = vector.getType().cast<ShapedType>().getElementType();
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("add"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("mul"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::minf:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minf"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minsi"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minui"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxf:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxf"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxsi"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxui"),
                                               vector, ValueRange{});
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrName(),
                      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                             builder.getContext()));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
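  // For reference, a typical instance of the custom form parsed below looks
  // roughly like the following (shapes, maps and names are illustrative only):
  //   %res = vector.contract {indexing_maps = [#map_a, #map_b, #map_c],
  //                           iterator_types = ["parallel", "parallel",
  //                                             "reduction"]}
  //     %lhs, %rhs, %acc
  //     : vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>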
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
    result.addAttribute(ContractionOp::getKindAttrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs())
    if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << lhs() << ", ";
  p << rhs() << ", " << acc();
  if (masks().size() == 2)
    p << ", " << masks();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << lhs().getType() << ", " << rhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to
  // 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

LogicalResult ContractionOp::verify() {
  auto lhsType = getLhsType();
  auto rhsType = getRhsType();
  auto accType = getAccType();
  auto resType = getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (indexing_maps().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
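  // For example (illustrative, a plain matmul contraction): with
  // iterator_types = ["parallel", "parallel", "reduction"], the three maps
  //   (m, n, k) -> (m, k)   for the lhs,
  //   (m, n, k) -> (k, n)   for the rhs, and
  //   (m, n, k) -> (m, n)   for the accumulator/result
  // each take 3 inputs, use no symbols, and produce as many results as the
  // rank of the corresponding vector operand.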
  unsigned numIterators = iterator_types().getValue().size();
  for (const auto &it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and
    // indices. This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = getLHSVectorMaskType();
  auto rhsMaskType = getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ?
      vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(kind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrName()};
  return llvm::makeArrayRef(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>>
ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  return llvm::to_vector<4>(
      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
        return mapAttr.cast<AffineMapAttr>().getValue();
      }));
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which rewrites a pattern such as:
///
/// ```mlir
///    %c0 = vector.constant 0: ...
///    %c = vector.contract %a, %b, %c0: ...
///    %e = add %c, %d: ...
/// ```
///
/// into:
///
/// ```mlir
///    %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.acc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.acc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.acc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ?
        success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source) {
  result.addOperands({source});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source,
                                     Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getVectorType();
  if (vectorType.getRank() == 0) {
    if (position())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!position())
    return emitOpError("expected position for 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
    return vectorType.getElementType();
  return VectorType::get(vectorType.getShape().drop_front(position.size()),
                         vectorType.getElementType());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  result.addOperands(source);
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
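// For example (illustrative): given %src : vector<4x8xf32> and
// %c3 = arith.constant 3 : index, calling this builder with position = {%c3}
// builds the same op as
//   %e = vector.extract %src[3] : vector<4x8xf32>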
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  build(builder, result, source, positionConstants);
}

void vector::ExtractOp::print(OpAsmPrinter &p) {
  p << " " << vector() << position();
  p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
  p << " : " << vector().getType();
}

ParseResult vector::ExtractOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) ||
      parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}

LogicalResult vector::ExtractOp::verify() {
  auto positionAttr = position().getValue();
  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= getVectorType().getDimSize(en.index()))
      return emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extrPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extrPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result
/// vector. As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as
  /// equality of all entries of `a` that are >= 0 with the corresponding
  /// entries of `b`. Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto it : llvm::zip(a, b)) {
      if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
        continue;
      if (std::get<0>(it) != std::get<1>(it))
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example (illustrative shapes):
  /// ```
  ///   %ins = vector.insert %source, %dest[1]
  ///            : vector<3x4x5> into vector<2x3x4x5>
  ///   // extractPosition == [1, 2, 3]
  ///   %ext = vector.extract %ins[1, 2, 3] : vector<2x3x4x5>
  ///   // can fold to:
  ///   %ext = vector.extract %source[2, 3] : vector<3x4x5>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the
  /// result. They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra
  // transposition operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
      extractedRank(extractOp.position().size()) {
  assert(vectorRank >= extractedRank && "extracted pos overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition = extractVector<int64_t>(extractOp.position());
  llvm::append_range(extractPosition, sentinels);
}

// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (!nextTransposeOp)
    return failure();
  auto permutation = extractVector<unsigned>(nextTransposeOp.transp());
  AffineMap m = inversePermutation(
      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
  extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
  return success();
}

// Case 2: the insert position matches extractPosition exactly, early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (makeArrayRef(insertedPos) !=
      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.source();
  // Case 2.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Case 3: if the inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
  // Drop extra leading dims.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
  extractedRank = extractPosition.size() - sentinels.size();
  // Case 3.a. early-exit fold (break and delegate to post-while path).
  res = nextInsertOp.source();
  // Case 3.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // If we can't fold (either internal transposition, or nothing to fold),
  // bail.
  bool nothingToFold = (source == extractOp.vector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold by updating the op in place and return its result.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(
      extractOp.positionAttrName(),
      b.getI64ArrayAttr(
          makeArrayRef(extractPosition).take_front(extractedRank)));
  extractOp.vectorMutable().assign(source);
  return extractOp.getResult();
}

/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {
  Value valueToExtractFrom = extractOp.vector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.vector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.dest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.vector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();
  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ?
        type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(source.getType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(source);
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp->setAttr(ExtractOp::getPositionAttrName(),
                       b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  return Value();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}

/// Fold an ExtractOp from ExtractStridedSliceOp.
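/// For example (illustrative shapes and offsets):
/// ```
///   %s = vector.extract_strided_slice %v
///          {offsets = [1, 2], sizes = [2, 2], strides = [1, 1]}
///          : vector<4x8xf32> to vector<2x2xf32>
///   %e = vector.extract %s[0, 1] : vector<2x2xf32>
/// ```
/// can fold to:
/// ```
///   %e = vector.extract %v[1, 3] : vector<4x8xf32>
/// ```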
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  auto extractStridedSliceOp =
      extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();
  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank >
      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
    return Value();
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(extractedPos));
  return extractOp.getResult();
}

/// Fold extract_op fed from a chain of insertStridedSlice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
  int64_t destinationRank = op.getType().isa<VectorType>()
                                ? op.getType().cast<VectorType>().getRank()
                                : 0;
  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
    auto extractOffsets = extractVector<int64_t>(op.position());

    if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted element chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap.
1436 int64_t srcRankDiff = 1437 insertOp.getSourceVectorType().getRank() - destinationRank; 1438 for (int64_t i = 0; i < destinationRank; i++) { 1439 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) != 1440 insertOp.getDestVectorType().getDimSize(i + srcRankDiff + 1441 insertRankDiff)) 1442 return Value(); 1443 } 1444 op.vectorMutable().assign(insertOp.source()); 1445 // OpBuilder is only used as a helper to build an I64ArrayAttr. 1446 OpBuilder b(op.getContext()); 1447 op->setAttr(ExtractOp::getPositionAttrName(), 1448 b.getI64ArrayAttr(offsetDiffs)); 1449 return op.getResult(); 1450 } 1451 // If the chunk extracted is disjoint from the chunk inserted, keep 1452 // looking in the insert chain. 1453 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>(); 1454 } 1455 return Value(); 1456 } 1457 1458 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) { 1459 if (position().empty()) 1460 return vector(); 1461 if (succeeded(foldExtractOpFromExtractChain(*this))) 1462 return getResult(); 1463 if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) 1464 return res; 1465 if (auto res = foldExtractFromBroadcast(*this)) 1466 return res; 1467 if (auto res = foldExtractFromShapeCast(*this)) 1468 return res; 1469 if (auto val = foldExtractFromExtractStrided(*this)) 1470 return val; 1471 if (auto val = foldExtractStridedOpFromInsertChain(*this)) 1472 return val; 1473 return OpFoldResult(); 1474 } 1475 1476 namespace { 1477 1478 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. 1479 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { 1480 public: 1481 using OpRewritePattern<ExtractOp>::OpRewritePattern; 1482 1483 LogicalResult matchAndRewrite(ExtractOp extractOp, 1484 PatternRewriter &rewriter) const override { 1485 Operation *defOp = extractOp.vector().getDefiningOp(); 1486 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) 1487 return failure(); 1488 Value source = defOp->getOperand(0); 1489 if (extractOp.getType() == source.getType()) 1490 return failure(); 1491 auto getRank = [](Type type) { 1492 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0; 1493 }; 1494 unsigned broadcastSrcRank = getRank(source.getType()); 1495 unsigned extractResultRank = getRank(extractOp.getType()); 1496 // We only consider the case where the rank of the source is smaller than 1497 // the rank of the extract dst. The other cases are handled in the folding 1498 // patterns. 1499 if (extractResultRank <= broadcastSrcRank) 1500 return failure(); 1501 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1502 extractOp, extractOp.getType(), source); 1503 return success(); 1504 } 1505 }; 1506 1507 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp. 1508 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> { 1509 public: 1510 using OpRewritePattern<ExtractOp>::OpRewritePattern; 1511 1512 LogicalResult matchAndRewrite(ExtractOp extractOp, 1513 PatternRewriter &rewriter) const override { 1514 // Return if 'extractStridedSliceOp' operand is not defined by a 1515 // ConstantOp. 
1516 auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>(); 1517 if (!constantOp) 1518 return failure(); 1519 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 1520 if (!dense) 1521 return failure(); 1522 Attribute newAttr = dense.getSplatValue<Attribute>(); 1523 if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>()) 1524 newAttr = DenseElementsAttr::get(vecDstType, newAttr); 1525 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr); 1526 return success(); 1527 } 1528 }; 1529 1530 } // namespace 1531 1532 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, 1533 MLIRContext *context) { 1534 results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context); 1535 } 1536 1537 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 1538 SmallVectorImpl<int64_t> &results) { 1539 for (auto attr : arrayAttr) 1540 results.push_back(attr.cast<IntegerAttr>().getInt()); 1541 } 1542 1543 //===----------------------------------------------------------------------===// 1544 // ExtractMapOp 1545 //===----------------------------------------------------------------------===// 1546 1547 void ExtractMapOp::build(OpBuilder &builder, OperationState &result, 1548 Value vector, ValueRange ids, 1549 ArrayRef<int64_t> multiplicity, 1550 AffineMap permutationMap) { 1551 assert(ids.size() == multiplicity.size() && 1552 ids.size() == permutationMap.getNumResults()); 1553 assert(permutationMap.isProjectedPermutation()); 1554 VectorType type = vector.getType().cast<VectorType>(); 1555 SmallVector<int64_t, 4> newShape(type.getShape().begin(), 1556 type.getShape().end()); 1557 for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) { 1558 AffineExpr expr = permutationMap.getResult(i); 1559 auto dim = expr.cast<AffineDimExpr>(); 1560 newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i]; 1561 } 1562 VectorType resultType = VectorType::get(newShape, type.getElementType()); 1563 ExtractMapOp::build(builder, result, resultType, vector, ids); 1564 } 1565 1566 LogicalResult ExtractMapOp::verify() { 1567 if (getSourceVectorType().getRank() != getResultType().getRank()) 1568 return emitOpError("expected source and destination vectors of same rank"); 1569 unsigned numId = 0; 1570 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) { 1571 if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) != 1572 0) 1573 return emitOpError("source vector dimensions must be a multiple of " 1574 "destination vector dimensions"); 1575 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) 1576 numId++; 1577 } 1578 if (numId != ids().size()) 1579 return emitOpError("expected number of ids must match the number of " 1580 "dimensions distributed"); 1581 return success(); 1582 } 1583 1584 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) { 1585 auto insert = vector().getDefiningOp<vector::InsertMapOp>(); 1586 if (insert == nullptr || getType() != insert.vector().getType() || 1587 ids() != insert.ids()) 1588 return {}; 1589 return insert.vector(); 1590 } 1591 1592 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) { 1593 assert(multiplicity.empty()); 1594 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) { 1595 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) 1596 multiplicity.push_back(getSourceVectorType().getDimSize(i) / 1597 getResultType().getDimSize(i)); 1598 } 1599 } 1600 1601 template 
<typename MapOp> 1602 AffineMap calculateImplicitMap(MapOp op) { 1603 SmallVector<AffineExpr, 4> perm; 1604 // Check which dimension have a multiplicity greater than 1 and associated 1605 // them to the IDs in order. 1606 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) { 1607 if (op.getSourceVectorType().getDimSize(i) != 1608 op.getResultType().getDimSize(i)) 1609 perm.push_back(getAffineDimExpr(i, op.getContext())); 1610 } 1611 auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm, 1612 op.getContext()); 1613 return map; 1614 } 1615 1616 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); } 1617 1618 //===----------------------------------------------------------------------===// 1619 // FmaOp 1620 //===----------------------------------------------------------------------===// 1621 1622 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() { 1623 return llvm::to_vector<4>(getVectorType().getShape()); 1624 } 1625 1626 //===----------------------------------------------------------------------===// 1627 // BroadcastOp 1628 //===----------------------------------------------------------------------===// 1629 1630 BroadcastableToResult 1631 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, 1632 std::pair<int, int> *mismatchingDims) { 1633 // Broadcast scalar to vector of the same element type. 1634 if (srcType.isIntOrIndexOrFloat() && dstVectorType && 1635 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) 1636 return BroadcastableToResult::Success; 1637 // From now on, only vectors broadcast. 1638 VectorType srcVectorType = srcType.dyn_cast<VectorType>(); 1639 if (!srcVectorType) 1640 return BroadcastableToResult::SourceTypeNotAVector; 1641 1642 int64_t srcRank = srcVectorType.getRank(); 1643 int64_t dstRank = dstVectorType.getRank(); 1644 if (srcRank > dstRank) 1645 return BroadcastableToResult::SourceRankHigher; 1646 // Source has an exact match or singleton value for all trailing dimensions 1647 // (all leading dimensions are simply duplicated). 1648 int64_t lead = dstRank - srcRank; 1649 for (int64_t r = 0; r < srcRank; ++r) { 1650 int64_t srcDim = srcVectorType.getDimSize(r); 1651 int64_t dstDim = dstVectorType.getDimSize(lead + r); 1652 if (srcDim != 1 && srcDim != dstDim) { 1653 if (mismatchingDims) { 1654 mismatchingDims->first = srcDim; 1655 mismatchingDims->second = dstDim; 1656 } 1657 return BroadcastableToResult::DimensionMismatch; 1658 } 1659 } 1660 1661 return BroadcastableToResult::Success; 1662 } 1663 1664 LogicalResult BroadcastOp::verify() { 1665 std::pair<int, int> mismatchingDims; 1666 BroadcastableToResult res = 1667 isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims); 1668 if (res == BroadcastableToResult::Success) 1669 return success(); 1670 if (res == BroadcastableToResult::SourceRankHigher) 1671 return emitOpError("source rank higher than destination rank"); 1672 if (res == BroadcastableToResult::DimensionMismatch) 1673 return emitOpError("dimension mismatch (") 1674 << mismatchingDims.first << " vs. 
" << mismatchingDims.second << ")"; 1675 if (res == BroadcastableToResult::SourceTypeNotAVector) 1676 return emitOpError("source type is not a vector"); 1677 llvm_unreachable("unexpected vector.broadcast op error"); 1678 } 1679 1680 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 1681 if (getSourceType() == getVectorType()) 1682 return source(); 1683 if (!operands[0]) 1684 return {}; 1685 auto vectorType = getVectorType(); 1686 if (operands[0].getType().isIntOrIndexOrFloat()) 1687 return DenseElementsAttr::get(vectorType, operands[0]); 1688 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>()) 1689 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()); 1690 return {}; 1691 } 1692 1693 namespace { 1694 1695 // Fold broadcast1(broadcast2(x)) into broadcast1(x). 1696 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { 1697 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 1698 1699 LogicalResult matchAndRewrite(BroadcastOp broadcastOp, 1700 PatternRewriter &rewriter) const override { 1701 auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>(); 1702 if (!srcBroadcast) 1703 return failure(); 1704 rewriter.replaceOpWithNewOp<BroadcastOp>( 1705 broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source()); 1706 return success(); 1707 } 1708 }; 1709 } // namespace 1710 1711 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, 1712 MLIRContext *context) { 1713 // BroadcastToShapeCast is not a default canonicalization, it is opt-in by 1714 // calling `populateCastAwayVectorLeadingOneDimPatterns` 1715 results.add<BroadcastFolder>(context); 1716 } 1717 1718 //===----------------------------------------------------------------------===// 1719 // ShuffleOp 1720 //===----------------------------------------------------------------------===// 1721 1722 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, 1723 Value v2, ArrayRef<int64_t> mask) { 1724 result.addOperands({v1, v2}); 1725 auto maskAttr = getVectorSubscriptAttr(builder, mask); 1726 auto v1Type = v1.getType().cast<VectorType>(); 1727 auto shape = llvm::to_vector<4>(v1Type.getShape()); 1728 shape[0] = mask.size(); 1729 result.addTypes(VectorType::get(shape, v1Type.getElementType())); 1730 result.addAttribute(getMaskAttrName(), maskAttr); 1731 } 1732 1733 void ShuffleOp::print(OpAsmPrinter &p) { 1734 p << " " << v1() << ", " << v2() << " " << mask(); 1735 p.printOptionalAttrDict((*this)->getAttrs(), {ShuffleOp::getMaskAttrName()}); 1736 p << " : " << v1().getType() << ", " << v2().getType(); 1737 } 1738 1739 LogicalResult ShuffleOp::verify() { 1740 VectorType resultType = getVectorType(); 1741 VectorType v1Type = getV1VectorType(); 1742 VectorType v2Type = getV2VectorType(); 1743 // Verify ranks. 1744 int64_t resRank = resultType.getRank(); 1745 int64_t v1Rank = v1Type.getRank(); 1746 int64_t v2Rank = v2Type.getRank(); 1747 if (resRank != v1Rank || v1Rank != v2Rank) 1748 return emitOpError("rank mismatch"); 1749 // Verify all but leading dimension sizes. 1750 for (int64_t r = 1; r < v1Rank; ++r) { 1751 int64_t resDim = resultType.getDimSize(r); 1752 int64_t v1Dim = v1Type.getDimSize(r); 1753 int64_t v2Dim = v2Type.getDimSize(r); 1754 if (resDim != v1Dim || v1Dim != v2Dim) 1755 return emitOpError("dimension mismatch"); 1756 } 1757 // Verify mask length. 
1758 auto maskAttr = mask().getValue(); 1759 int64_t maskLength = maskAttr.size(); 1760 if (maskLength != resultType.getDimSize(0)) 1761 return emitOpError("mask length mismatch"); 1762 // Verify all indices. 1763 int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); 1764 for (const auto &en : llvm::enumerate(maskAttr)) { 1765 auto attr = en.value().dyn_cast<IntegerAttr>(); 1766 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) 1767 return emitOpError("mask index #") << (en.index() + 1) << " out of range"; 1768 } 1769 return success(); 1770 } 1771 1772 ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { 1773 OpAsmParser::OperandType v1, v2; 1774 Attribute attr; 1775 VectorType v1Type, v2Type; 1776 if (parser.parseOperand(v1) || parser.parseComma() || 1777 parser.parseOperand(v2) || 1778 parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(), 1779 result.attributes) || 1780 parser.parseOptionalAttrDict(result.attributes) || 1781 parser.parseColonType(v1Type) || parser.parseComma() || 1782 parser.parseType(v2Type) || 1783 parser.resolveOperand(v1, v1Type, result.operands) || 1784 parser.resolveOperand(v2, v2Type, result.operands)) 1785 return failure(); 1786 // Construct resulting type: leading dimension matches mask length, 1787 // all trailing dimensions match the operands. 1788 auto maskAttr = attr.dyn_cast<ArrayAttr>(); 1789 if (!maskAttr) 1790 return parser.emitError(parser.getNameLoc(), "missing mask attribute"); 1791 int64_t maskLength = maskAttr.size(); 1792 if (maskLength <= 0) 1793 return parser.emitError(parser.getNameLoc(), "invalid mask length"); 1794 int64_t v1Rank = v1Type.getRank(); 1795 SmallVector<int64_t, 4> shape; 1796 shape.reserve(v1Rank); 1797 shape.push_back(maskLength); 1798 for (int64_t r = 1; r < v1Rank; ++r) 1799 shape.push_back(v1Type.getDimSize(r)); 1800 VectorType resType = VectorType::get(shape, v1Type.getElementType()); 1801 parser.addTypeToList(resType, result.types); 1802 return success(); 1803 } 1804 1805 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) { 1806 Attribute lhs = operands.front(), rhs = operands.back(); 1807 if (!lhs || !rhs) 1808 return {}; 1809 1810 auto lhsType = lhs.getType().cast<VectorType>(); 1811 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr 1812 // manipulation. 
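// Illustrative example (values are hypothetical): with lhs = dense<[0, 1, 2, 3]>,
// rhs = dense<[4, 5, 6, 7]> and mask [0, 5], index 0 selects lhs[0] and index 5
// selects rhs[5 - 4], so the folded constant is dense<[0, 5]>.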
1813 if (lhsType.getRank() != 1) 1814 return {}; 1815 int64_t lhsSize = lhsType.getDimSize(0); 1816 1817 SmallVector<Attribute> results; 1818 auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>(); 1819 auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>(); 1820 for (const auto &index : this->mask().getAsValueRange<IntegerAttr>()) { 1821 int64_t i = index.getZExtValue(); 1822 if (i >= lhsSize) { 1823 results.push_back(rhsElements[i - lhsSize]); 1824 } else { 1825 results.push_back(lhsElements[i]); 1826 } 1827 } 1828 1829 return DenseElementsAttr::get(getVectorType(), results); 1830 } 1831 1832 //===----------------------------------------------------------------------===// 1833 // InsertElementOp 1834 //===----------------------------------------------------------------------===// 1835 1836 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 1837 Value source, Value dest) { 1838 result.addOperands({source, dest}); 1839 result.addTypes(dest.getType()); 1840 } 1841 1842 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 1843 Value source, Value dest, Value position) { 1844 result.addOperands({source, dest, position}); 1845 result.addTypes(dest.getType()); 1846 } 1847 1848 LogicalResult InsertElementOp::verify() { 1849 auto dstVectorType = getDestVectorType(); 1850 if (dstVectorType.getRank() == 0) { 1851 if (position()) 1852 return emitOpError("expected position to be empty with 0-D vector"); 1853 return success(); 1854 } 1855 if (dstVectorType.getRank() != 1) 1856 return emitOpError("unexpected >1 vector rank"); 1857 if (!position()) 1858 return emitOpError("expected position for 1-D vector"); 1859 return success(); 1860 } 1861 1862 //===----------------------------------------------------------------------===// 1863 // InsertOp 1864 //===----------------------------------------------------------------------===// 1865 1866 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1867 Value dest, ArrayRef<int64_t> position) { 1868 result.addOperands({source, dest}); 1869 auto positionAttr = getVectorSubscriptAttr(builder, position); 1870 result.addTypes(dest.getType()); 1871 result.addAttribute(getPositionAttrName(), positionAttr); 1872 } 1873 1874 // Convenience builder which assumes the values are constant indices. 
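// Note: each value in `position` is assumed to be produced by an arith.constant
// of index type; passing a dynamically computed position to this builder is not
// supported.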
1875 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1876 Value dest, ValueRange position) { 1877 SmallVector<int64_t, 4> positionConstants = 1878 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { 1879 return pos.getDefiningOp<arith::ConstantIndexOp>().value(); 1880 })); 1881 build(builder, result, source, dest, positionConstants); 1882 } 1883 1884 LogicalResult InsertOp::verify() { 1885 auto positionAttr = position().getValue(); 1886 auto destVectorType = getDestVectorType(); 1887 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank())) 1888 return emitOpError( 1889 "expected position attribute of rank smaller than dest vector rank"); 1890 auto srcVectorType = getSourceType().dyn_cast<VectorType>(); 1891 if (srcVectorType && 1892 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() != 1893 static_cast<unsigned>(destVectorType.getRank()))) 1894 return emitOpError("expected position attribute rank + source rank to " 1895 "match dest vector rank"); 1896 if (!srcVectorType && 1897 (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank()))) 1898 return emitOpError( 1899 "expected position attribute rank to match the dest vector rank"); 1900 for (const auto &en : llvm::enumerate(positionAttr)) { 1901 auto attr = en.value().dyn_cast<IntegerAttr>(); 1902 if (!attr || attr.getInt() < 0 || 1903 attr.getInt() >= destVectorType.getDimSize(en.index())) 1904 return emitOpError("expected position attribute #") 1905 << (en.index() + 1) 1906 << " to be a non-negative integer smaller than the corresponding " 1907 "dest vector dimension"; 1908 } 1909 return success(); 1910 } 1911 1912 namespace { 1913 1914 // If insertOp is only inserting unit dimensions it can be transformed to a 1915 // broadcast. 1916 class InsertToBroadcast final : public OpRewritePattern<InsertOp> { 1917 public: 1918 using OpRewritePattern<InsertOp>::OpRewritePattern; 1919 1920 LogicalResult matchAndRewrite(InsertOp insertOp, 1921 PatternRewriter &rewriter) const override { 1922 auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>(); 1923 if (!srcVecType || insertOp.getDestVectorType().getNumElements() != 1924 srcVecType.getNumElements()) 1925 return failure(); 1926 rewriter.replaceOpWithNewOp<BroadcastOp>( 1927 insertOp, insertOp.getDestVectorType(), insertOp.source()); 1928 return success(); 1929 } 1930 }; 1931 1932 } // namespace 1933 1934 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, 1935 MLIRContext *context) { 1936 results.add<InsertToBroadcast, BroadcastFolder>(context); 1937 } 1938 1939 // Eliminates insert operations that produce values identical to their source 1940 // value. This happens when the source and destination vectors have identical 1941 // sizes. 
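// Illustrative example: inserting a vector<4xf32> into a vector<4xf32> with an
// empty position simply yields the inserted source vector.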
OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
  if (position().empty())
    return source();
  return {};
}

//===----------------------------------------------------------------------===//
// InsertMapOp
//===----------------------------------------------------------------------===//

void InsertMapOp::build(OpBuilder &builder, OperationState &result,
                        Value vector, Value dest, ValueRange ids) {
  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
}

LogicalResult InsertMapOp::verify() {
  if (getSourceVectorType().getRank() != getResultType().getRank())
    return emitOpError("expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) {
    if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) !=
        0)
      return emitOpError(
          "destination vector size must be a multiple of source vector size");
    if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i))
      numId++;
  }
  if (numId != ids().size())
    return emitOpError("expected number of ids must match the number of "
                       "dimensions distributed");
  return success();
}

AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }

//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

// TODO: Should be moved to Tablegen Confined attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank smaller than vector rank";
  return success();
}

// Returns success if all integers in `arrayAttr` are in the range [min, max].
// If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
// the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}

// Returns success if all integers in `arrayAttr` are confined by the
// corresponding dimension of `shape`. If `halfOpen` is true then the
// admissible interval for dimension d is [min, shape[d]). Otherwise, the
// admissible interval is [min, shape[d]].
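// For example (illustrative): with shape = [4, 8] and halfOpen = true, the
// attribute [3, 7] is accepted, while [4, 0] is rejected because 4 is not
// strictly smaller than 4.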
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  assert(arrayAttr.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr, shape)) {
    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<1>(it);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
    ++index;
  }
  return success();
}

// Returns success if, for each dimension, the sum of the corresponding
// integers in `arrayAttr1` and `arrayAttr2` is confined by that dimension of
// `shape`. If `halfOpen` is true then the admissible interval for dimension d
// is [min, shape[d]). Otherwise, the admissible interval is [min, shape[d]].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
    ++index;
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}

LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = offsetsAttr();
  auto strides = stridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be smaller than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
                                               stridesName,
                                               /*halfOpen=*/false)) ||
failed(isSumOfIntegerArrayAttrConfinedToShape( 2110 *this, offsets, 2111 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape, 2112 offName, "source vector shape", 2113 /*halfOpen=*/false, /*min=*/1))) 2114 return failure(); 2115 2116 return success(); 2117 } 2118 2119 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2120 if (getSourceVectorType() == getDestVectorType()) 2121 return source(); 2122 return {}; 2123 } 2124 2125 //===----------------------------------------------------------------------===// 2126 // OuterProductOp 2127 //===----------------------------------------------------------------------===// 2128 2129 /// Build an op without mask, use the type of `acc` as the return type. 2130 void OuterProductOp::build(OpBuilder &builder, OperationState &result, 2131 Value lhs, Value rhs, Value acc) { 2132 result.addOperands({lhs, rhs, acc}); 2133 result.addTypes(acc.getType()); 2134 } 2135 2136 void OuterProductOp::print(OpAsmPrinter &p) { 2137 p << " " << lhs() << ", " << rhs(); 2138 if (!acc().empty()) { 2139 p << ", " << acc(); 2140 p.printOptionalAttrDict((*this)->getAttrs()); 2141 } 2142 p << " : " << lhs().getType() << ", " << rhs().getType(); 2143 } 2144 2145 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { 2146 SmallVector<OpAsmParser::OperandType, 3> operandsInfo; 2147 Type tLHS, tRHS; 2148 if (parser.parseOperandList(operandsInfo) || 2149 parser.parseOptionalAttrDict(result.attributes) || 2150 parser.parseColonType(tLHS) || parser.parseComma() || 2151 parser.parseType(tRHS)) 2152 return failure(); 2153 if (operandsInfo.size() < 2) 2154 return parser.emitError(parser.getNameLoc(), 2155 "expected at least 2 operands"); 2156 VectorType vLHS = tLHS.dyn_cast<VectorType>(); 2157 VectorType vRHS = tRHS.dyn_cast<VectorType>(); 2158 if (!vLHS) 2159 return parser.emitError(parser.getNameLoc(), 2160 "expected vector type for operand #1"); 2161 VectorType resType = 2162 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, 2163 vLHS.getElementType()) 2164 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); 2165 2166 if (!result.attributes.get(OuterProductOp::getKindAttrName())) { 2167 result.attributes.append( 2168 OuterProductOp::getKindAttrName(), 2169 CombiningKindAttr::get(OuterProductOp::getDefaultKind(), 2170 result.getContext())); 2171 } 2172 2173 return failure( 2174 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || 2175 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || 2176 (operandsInfo.size() > 2 && 2177 parser.resolveOperand(operandsInfo[2], resType, result.operands)) || 2178 parser.addTypeToList(resType, result.types)); 2179 } 2180 2181 LogicalResult OuterProductOp::verify() { 2182 Type tRHS = getOperandTypeRHS(); 2183 VectorType vLHS = getOperandVectorTypeLHS(), 2184 vRHS = tRHS.dyn_cast<VectorType>(), 2185 vACC = getOperandVectorTypeACC(), vRES = getVectorType(); 2186 2187 if (vLHS.getRank() != 1) 2188 return emitOpError("expected 1-d vector for operand #1"); 2189 2190 if (vRHS) { 2191 // Proper OUTER operation. 
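// Illustrative shape relation: lhs vector<4xf32> and rhs vector<8xf32>
// yield a vector<4x8xf32> result (dim #1 comes from lhs, dim #2 from rhs).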
2192 if (vRHS.getRank() != 1) 2193 return emitOpError("expected 1-d vector for operand #2"); 2194 if (vRES.getRank() != 2) 2195 return emitOpError("expected 2-d vector result"); 2196 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2197 return emitOpError("expected #1 operand dim to match result dim #1"); 2198 if (vRHS.getDimSize(0) != vRES.getDimSize(1)) 2199 return emitOpError("expected #2 operand dim to match result dim #2"); 2200 } else { 2201 // An AXPY operation. 2202 if (vRES.getRank() != 1) 2203 return emitOpError("expected 1-d vector result"); 2204 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2205 return emitOpError("expected #1 operand dim to match result dim #1"); 2206 } 2207 2208 if (vACC && vACC != vRES) 2209 return emitOpError("expected operand #3 of same type as result type"); 2210 2211 // Verify supported combining kind. 2212 if (!isSupportedCombiningKind(kind(), vRES.getElementType())) 2213 return emitOpError("unsupported outerproduct type"); 2214 2215 return success(); 2216 } 2217 2218 //===----------------------------------------------------------------------===// 2219 // ReshapeOp 2220 //===----------------------------------------------------------------------===// 2221 2222 LogicalResult ReshapeOp::verify() { 2223 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. 2224 auto inputVectorType = getInputVectorType(); 2225 auto outputVectorType = getOutputVectorType(); 2226 int64_t inputShapeRank = getNumInputShapeSizes(); 2227 int64_t outputShapeRank = getNumOutputShapeSizes(); 2228 SmallVector<int64_t, 4> fixedVectorSizes; 2229 getFixedVectorSizes(fixedVectorSizes); 2230 int64_t numFixedVectorSizes = fixedVectorSizes.size(); 2231 2232 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) 2233 return emitError("invalid input shape for vector type ") 2234 << inputVectorType; 2235 2236 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) 2237 return emitError("invalid output shape for vector type ") 2238 << outputVectorType; 2239 2240 // Verify that the 'fixedVectorSizes' match an input/output vector shape 2241 // suffix. 2242 unsigned inputVectorRank = inputVectorType.getRank(); 2243 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2244 unsigned index = inputVectorRank - numFixedVectorSizes - i; 2245 if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) 2246 return emitError("fixed vector size must match input vector for dim ") 2247 << i; 2248 } 2249 2250 unsigned outputVectorRank = outputVectorType.getRank(); 2251 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2252 unsigned index = outputVectorRank - numFixedVectorSizes - i; 2253 if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) 2254 return emitError("fixed vector size must match output vector for dim ") 2255 << i; 2256 } 2257 2258 // If all shape operands are produced by constant ops, verify that product 2259 // of dimensions for input/output shape match. 
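// For example (illustrative): an input shape of (2, 6) and an output shape of
// (3, 4) are compatible since 2 * 6 == 3 * 4 == 12 elements.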
2260 auto isDefByConstant = [](Value operand) { 2261 return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp()); 2262 }; 2263 if (llvm::all_of(input_shape(), isDefByConstant) && 2264 llvm::all_of(output_shape(), isDefByConstant)) { 2265 int64_t numInputElements = 1; 2266 for (auto operand : input_shape()) 2267 numInputElements *= 2268 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2269 int64_t numOutputElements = 1; 2270 for (auto operand : output_shape()) 2271 numOutputElements *= 2272 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2273 if (numInputElements != numOutputElements) 2274 return emitError("product of input and output shape sizes must match"); 2275 } 2276 return success(); 2277 } 2278 2279 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) { 2280 populateFromInt64AttrArray(fixed_vector_sizes(), results); 2281 } 2282 2283 //===----------------------------------------------------------------------===// 2284 // ExtractStridedSliceOp 2285 //===----------------------------------------------------------------------===// 2286 2287 // Inference works as follows: 2288 // 1. Add 'sizes' from prefix of dims in 'offsets'. 2289 // 2. Add sizes from 'vectorType' for remaining dims. 2290 static Type inferStridedSliceOpResultType(VectorType vectorType, 2291 ArrayAttr offsets, ArrayAttr sizes, 2292 ArrayAttr strides) { 2293 assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); 2294 SmallVector<int64_t, 4> shape; 2295 shape.reserve(vectorType.getRank()); 2296 unsigned idx = 0; 2297 for (unsigned e = offsets.size(); idx < e; ++idx) 2298 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt()); 2299 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) 2300 shape.push_back(vectorType.getShape()[idx]); 2301 2302 return VectorType::get(shape, vectorType.getElementType()); 2303 } 2304 2305 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, 2306 Value source, ArrayRef<int64_t> offsets, 2307 ArrayRef<int64_t> sizes, 2308 ArrayRef<int64_t> strides) { 2309 result.addOperands(source); 2310 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 2311 auto sizesAttr = getVectorSubscriptAttr(builder, sizes); 2312 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 2313 result.addTypes( 2314 inferStridedSliceOpResultType(source.getType().cast<VectorType>(), 2315 offsetsAttr, sizesAttr, stridesAttr)); 2316 result.addAttribute(getOffsetsAttrName(), offsetsAttr); 2317 result.addAttribute(getSizesAttrName(), sizesAttr); 2318 result.addAttribute(getStridesAttrName(), stridesAttr); 2319 } 2320 2321 LogicalResult ExtractStridedSliceOp::verify() { 2322 auto type = getVectorType(); 2323 auto offsets = offsetsAttr(); 2324 auto sizes = sizesAttr(); 2325 auto strides = stridesAttr(); 2326 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) 2327 return emitOpError("expected offsets, sizes and strides attributes of same size"); 2328 2329 auto shape = type.getShape(); 2330 auto offName = getOffsetsAttrName(); 2331 auto sizesName = getSizesAttrName(); 2332 auto stridesName = getStridesAttrName(); 2333 if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || 2334 failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || 2335 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape, 2336 stridesName)) || 2337 failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || 2338 
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                                /*halfOpen=*/false,
                                                /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
                                               stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType =
      inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  return success();
}

// When the source of ExtractStrided comes from a chain of InsertStrided ops,
// try to use the source of the InsertStrided ops if we can detect that the
// extracted vector is a subset of one of the vectors inserted.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract integer out of ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return array[idx].cast<IntegerAttr>().getInt();
  };
  ArrayAttr extractOffsets = op.offsets();
  ArrayAttr extractStrides = op.strides();
  ArrayAttr extractSizes = op.sizes();
  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.offsets();
    ArrayAttr insertStrides = insertOp.strides();
    // If the rank of extract is greater than the rank of insert, we are likely
    // extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included we may
        // have a partial overlap that will prevent any folding.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted chunk.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.source());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep looking
    // in the insert chain.
    if (disjoint)
      insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
    else {
      // The extracted vector partially overlaps the inserted vector; we cannot
      // fold.
2417 return failure(); 2418 } 2419 } 2420 return failure(); 2421 } 2422 2423 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2424 if (getVectorType() == getResult().getType()) 2425 return vector(); 2426 if (succeeded(foldExtractStridedOpFromInsertChain(*this))) 2427 return getResult(); 2428 return {}; 2429 } 2430 2431 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) { 2432 populateFromInt64AttrArray(offsets(), results); 2433 } 2434 2435 namespace { 2436 2437 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to 2438 // ConstantMaskOp. 2439 class StridedSliceConstantMaskFolder final 2440 : public OpRewritePattern<ExtractStridedSliceOp> { 2441 public: 2442 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2443 2444 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2445 PatternRewriter &rewriter) const override { 2446 // Return if 'extractStridedSliceOp' operand is not defined by a 2447 // ConstantMaskOp. 2448 auto *defOp = extractStridedSliceOp.vector().getDefiningOp(); 2449 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp); 2450 if (!constantMaskOp) 2451 return failure(); 2452 // Return if 'extractStridedSliceOp' has non-unit strides. 2453 if (extractStridedSliceOp.hasNonUnitStrides()) 2454 return failure(); 2455 // Gather constant mask dimension sizes. 2456 SmallVector<int64_t, 4> maskDimSizes; 2457 populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); 2458 // Gather strided slice offsets and sizes. 2459 SmallVector<int64_t, 4> sliceOffsets; 2460 populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets); 2461 SmallVector<int64_t, 4> sliceSizes; 2462 populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes); 2463 2464 // Compute slice of vector mask region. 2465 SmallVector<int64_t, 4> sliceMaskDimSizes; 2466 assert(sliceOffsets.size() == maskDimSizes.size()); 2467 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { 2468 int64_t maskDimSize = std::get<0>(it); 2469 int64_t sliceOffset = std::get<1>(it); 2470 int64_t sliceSize = std::get<2>(it); 2471 int64_t sliceMaskDimSize = std::max( 2472 static_cast<int64_t>(0), 2473 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset); 2474 sliceMaskDimSizes.push_back(sliceMaskDimSize); 2475 } 2476 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked 2477 // region is a conjunction of mask dim intervals). 2478 if (llvm::is_contained(sliceMaskDimSizes, 0)) 2479 sliceMaskDimSizes.assign(maskDimSizes.size(), 0); 2480 2481 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask 2482 // region. 2483 rewriter.replaceOpWithNewOp<ConstantMaskOp>( 2484 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), 2485 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); 2486 return success(); 2487 } 2488 }; 2489 2490 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. 2491 class StridedSliceConstantFolder final 2492 : public OpRewritePattern<ExtractStridedSliceOp> { 2493 public: 2494 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2495 2496 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2497 PatternRewriter &rewriter) const override { 2498 // Return if 'extractStridedSliceOp' operand is not defined by a 2499 // ConstantOp. 
    auto constantOp =
        extractStridedSliceOp.vector().getDefiningOp<arith::ConstantOp>();
    if (!constantOp)
      return failure();
    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
                                          dense.getSplatValue<Attribute>());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                   newAttr);
    return success();
  }
};

// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = op.getType().cast<VectorType>();
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are the
    // same as the destination of the extract. If this is the case we can just
    // use a broadcast as the original dimensions are untouched.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.source();
    if (!lowerDimMatch) {
      // The inner dimensions don't match, which means we need to extract from
      // the source of the original broadcast and then broadcast the extracted
      // value.
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
          getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
          getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto splat = op.vector().getDefiningOp<SplatOp>();
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input());
    return success();
  }
};

} // namespace

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2578 results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder, 2579 StridedSliceBroadcast, StridedSliceSplat>(context); 2580 } 2581 2582 //===----------------------------------------------------------------------===// 2583 // TransferReadOp 2584 //===----------------------------------------------------------------------===// 2585 2586 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). 2587 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2588 VectorType vectorType, Value source, 2589 ValueRange indices, AffineMapAttr permutationMapAttr, 2590 /*optional*/ ArrayAttr inBoundsAttr) { 2591 Type elemType = source.getType().cast<ShapedType>().getElementType(); 2592 Value padding = builder.create<arith::ConstantOp>( 2593 result.location, elemType, builder.getZeroAttr(elemType)); 2594 build(builder, result, vectorType, source, indices, permutationMapAttr, 2595 padding, /*mask=*/Value(), inBoundsAttr); 2596 } 2597 2598 /// 2. Builder that sets padding to zero an empty mask (variant without attrs). 2599 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2600 VectorType vectorType, Value source, 2601 ValueRange indices, AffineMap permutationMap, 2602 Optional<ArrayRef<bool>> inBounds) { 2603 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 2604 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 2605 ? builder.getBoolArrayAttr(inBounds.getValue()) 2606 : ArrayAttr(); 2607 build(builder, result, vectorType, source, indices, permutationMapAttr, 2608 inBoundsAttr); 2609 } 2610 2611 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. 2612 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2613 VectorType vectorType, Value source, 2614 ValueRange indices, Value padding, 2615 Optional<ArrayRef<bool>> inBounds) { 2616 AffineMap permutationMap = getTransferMinorIdentityMap( 2617 source.getType().cast<ShapedType>(), vectorType); 2618 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 2619 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 2620 ? builder.getBoolArrayAttr(inBounds.getValue()) 2621 : ArrayAttr(); 2622 build(builder, result, vectorType, source, indices, permutationMapAttr, 2623 padding, 2624 /*mask=*/Value(), inBoundsAttr); 2625 } 2626 2627 /// 4. Builder that sets padding to zero and permutation map to 2628 /// 'getMinorIdentityMap'. 
2629 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2630 VectorType vectorType, Value source, 2631 ValueRange indices, 2632 Optional<ArrayRef<bool>> inBounds) { 2633 Type elemType = source.getType().cast<ShapedType>().getElementType(); 2634 Value padding = builder.create<arith::ConstantOp>( 2635 result.location, elemType, builder.getZeroAttr(elemType)); 2636 build(builder, result, vectorType, source, indices, padding, inBounds); 2637 } 2638 2639 template <typename EmitFun> 2640 static LogicalResult verifyPermutationMap(AffineMap permutationMap, 2641 EmitFun emitOpError) { 2642 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); 2643 for (auto expr : permutationMap.getResults()) { 2644 auto dim = expr.dyn_cast<AffineDimExpr>(); 2645 auto zero = expr.dyn_cast<AffineConstantExpr>(); 2646 if (zero) { 2647 if (zero.getValue() != 0) { 2648 return emitOpError( 2649 "requires a projected permutation_map (at most one dim or the zero " 2650 "constant can appear in each result)"); 2651 } 2652 continue; 2653 } 2654 if (!dim) { 2655 return emitOpError("requires a projected permutation_map (at most one " 2656 "dim or the zero constant can appear in each result)"); 2657 } 2658 if (seen[dim.getPosition()]) { 2659 return emitOpError( 2660 "requires a permutation_map that is a permutation (found one dim " 2661 "used more than once)"); 2662 } 2663 seen[dim.getPosition()] = true; 2664 } 2665 return success(); 2666 } 2667 2668 static LogicalResult 2669 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, 2670 VectorType vectorType, VectorType maskType, 2671 AffineMap permutationMap, ArrayAttr inBounds) { 2672 if (op->hasAttr("masked")) { 2673 return op->emitOpError("masked attribute has been removed. " 2674 "Use in_bounds instead."); 2675 } 2676 2677 if (!shapedType.isa<MemRefType, RankedTensorType>()) 2678 return op->emitOpError( 2679 "requires source to be a memref or ranked tensor type"); 2680 2681 auto elementType = shapedType.getElementType(); 2682 DataLayout dataLayout = DataLayout::closest(op); 2683 if (auto vectorElementType = elementType.dyn_cast<VectorType>()) { 2684 // Memref or tensor has vector element type. 2685 unsigned sourceVecSize = 2686 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * 2687 vectorElementType.getShape().back(); 2688 unsigned resultVecSize = 2689 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * 2690 vectorType.getShape().back(); 2691 if (resultVecSize % sourceVecSize != 0) 2692 return op->emitOpError( 2693 "requires the bitwidth of the minor 1-D vector to be an integral " 2694 "multiple of the bitwidth of the minor 1-D vector of the source"); 2695 2696 unsigned sourceVecEltRank = vectorElementType.getRank(); 2697 unsigned resultVecRank = vectorType.getRank(); 2698 if (sourceVecEltRank > resultVecRank) 2699 return op->emitOpError( 2700 "requires source vector element and vector result ranks to match."); 2701 unsigned rankOffset = resultVecRank - sourceVecEltRank; 2702 // Check that permutation map results match 'rankOffset' of vector type. 2703 if (permutationMap.getNumResults() != rankOffset) 2704 return op->emitOpError("requires a permutation_map with result dims of " 2705 "the same rank as the vector type"); 2706 2707 if (maskType) 2708 return op->emitOpError("does not support masks with vector element type"); 2709 } else { 2710 // Memref or tensor has scalar element type. 2711 unsigned minorSize = 2712 vectorType.getRank() == 0 ? 
1 : vectorType.getShape().back(); 2713 unsigned resultVecSize = 2714 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize; 2715 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0) 2716 return op->emitOpError( 2717 "requires the bitwidth of the minor 1-D vector to be an integral " 2718 "multiple of the bitwidth of the source element type"); 2719 2720 // Check that permutation map results match rank of vector type. 2721 if (permutationMap.getNumResults() != vectorType.getRank()) 2722 return op->emitOpError("requires a permutation_map with result dims of " 2723 "the same rank as the vector type"); 2724 2725 VectorType expectedMaskType = 2726 vector::detail::transferMaskType(vectorType, permutationMap); 2727 if (maskType && expectedMaskType != maskType) 2728 return op->emitOpError("expects mask type consistent with permutation " 2729 "map: ") 2730 << maskType; 2731 } 2732 2733 if (permutationMap.getNumSymbols() != 0) 2734 return op->emitOpError("requires permutation_map without symbols"); 2735 2736 if (permutationMap.getNumInputs() != shapedType.getRank()) 2737 return op->emitOpError("requires a permutation_map with input dims of the " 2738 "same rank as the source type"); 2739 2740 if (inBounds) { 2741 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size())) 2742 return op->emitOpError("expects the optional in_bounds attr of same rank " 2743 "as permutation_map results: ") 2744 << AffineMapAttr::get(permutationMap) 2745 << " vs inBounds of size: " << inBounds.size(); 2746 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) 2747 if (permutationMap.getResult(i).isa<AffineConstantExpr>() && 2748 !inBounds.getValue()[i].cast<BoolAttr>().getValue()) 2749 return op->emitOpError("requires broadcast dimensions to be in-bounds"); 2750 } 2751 2752 return success(); 2753 } 2754 2755 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { 2756 SmallVector<StringRef, 3> elidedAttrs; 2757 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); 2758 if (op.permutation_map().isMinorIdentity()) 2759 elidedAttrs.push_back(op.getPermutationMapAttrName()); 2760 bool elideInBounds = true; 2761 if (auto inBounds = op.in_bounds()) { 2762 for (auto attr : *inBounds) { 2763 if (attr.template cast<BoolAttr>().getValue()) { 2764 elideInBounds = false; 2765 break; 2766 } 2767 } 2768 } 2769 if (elideInBounds) 2770 elidedAttrs.push_back(op.getInBoundsAttrName()); 2771 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); 2772 } 2773 2774 void TransferReadOp::print(OpAsmPrinter &p) { 2775 p << " " << source() << "[" << indices() << "], " << padding(); 2776 if (mask()) 2777 p << ", " << mask(); 2778 printTransferAttrs(p, *this); 2779 p << " : " << getShapedType() << ", " << getVectorType(); 2780 } 2781 2782 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { 2783 auto &builder = parser.getBuilder(); 2784 SMLoc typesLoc; 2785 OpAsmParser::OperandType sourceInfo; 2786 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 2787 OpAsmParser::OperandType paddingInfo; 2788 SmallVector<Type, 2> types; 2789 OpAsmParser::OperandType maskInfo; 2790 // Parsing with support for paddingValue. 
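// The expected textual form is roughly (illustrative):
//   vector.transfer_read %src[%i, %j], %pad, %mask {in_bounds = [true, false]}
//       : memref<?x?xf32>, vector<4x8xf32>
// where the mask operand and the attribute dictionary are optional.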
2791 if (parser.parseOperand(sourceInfo) || 2792 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 2793 parser.parseComma() || parser.parseOperand(paddingInfo)) 2794 return failure(); 2795 ParseResult hasMask = parser.parseOptionalComma(); 2796 if (hasMask.succeeded()) { 2797 parser.parseOperand(maskInfo); 2798 } 2799 if (parser.parseOptionalAttrDict(result.attributes) || 2800 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 2801 return failure(); 2802 if (types.size() != 2) 2803 return parser.emitError(typesLoc, "requires two types"); 2804 auto indexType = builder.getIndexType(); 2805 auto shapedType = types[0].dyn_cast<ShapedType>(); 2806 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 2807 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 2808 VectorType vectorType = types[1].dyn_cast<VectorType>(); 2809 if (!vectorType) 2810 return parser.emitError(typesLoc, "requires vector type"); 2811 auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); 2812 Attribute mapAttr = result.attributes.get(permutationAttrName); 2813 if (!mapAttr) { 2814 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 2815 // Update `mapAttr` that is used later to determine mask type. 2816 mapAttr = AffineMapAttr::get(permMap); 2817 result.attributes.set(permutationAttrName, mapAttr); 2818 } 2819 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || 2820 parser.resolveOperands(indexInfo, indexType, result.operands) || 2821 parser.resolveOperand(paddingInfo, shapedType.getElementType(), 2822 result.operands)) 2823 return failure(); 2824 if (hasMask.succeeded()) { 2825 if (shapedType.getElementType().dyn_cast<VectorType>()) 2826 return parser.emitError( 2827 maskInfo.location, "does not support masks with vector element type"); 2828 auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue(); 2829 // Instead of adding the mask type as an op type, compute it based on the 2830 // vector type and the permutation map (to keep the type signature small). 2831 auto maskType = mlir::vector::detail::transferMaskType(vectorType, map); 2832 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 2833 return failure(); 2834 } 2835 result.addAttribute( 2836 TransferReadOp::getOperandSegmentSizeAttr(), 2837 builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1, 2838 static_cast<int32_t>(hasMask.succeeded())})); 2839 return parser.addTypeToList(vectorType, result.types); 2840 } 2841 2842 LogicalResult TransferReadOp::verify() { 2843 // Consistency of elemental types in source and vector. 2844 ShapedType shapedType = getShapedType(); 2845 VectorType vectorType = getVectorType(); 2846 VectorType maskType = getMaskType(); 2847 auto paddingType = padding().getType(); 2848 auto permutationMap = permutation_map(); 2849 auto sourceElementType = shapedType.getElementType(); 2850 2851 if (static_cast<int64_t>(indices().size()) != shapedType.getRank()) 2852 return emitOpError("requires ") << shapedType.getRank() << " indices"; 2853 2854 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 2855 shapedType, vectorType, maskType, permutationMap, 2856 in_bounds() ? *in_bounds() : ArrayAttr()))) 2857 return failure(); 2858 2859 if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) { 2860 // Source has vector element type. 2861 // Check that 'sourceVectorElementType' and 'paddingType' types match. 
2862 if (sourceVectorElementType != paddingType) 2863 return emitOpError( 2864 "requires source element type and padding type to match."); 2865 2866 } else { 2867 // Check that 'paddingType' is valid to store in a vector type. 2868 if (!VectorType::isValidElementType(paddingType)) 2869 return emitOpError("requires valid padding vector elemental type"); 2870 2871 // Check that padding type and vector element types match. 2872 if (paddingType != sourceElementType) 2873 return emitOpError( 2874 "requires formal padding and source of the same elemental type"); 2875 } 2876 2877 return verifyPermutationMap(permutationMap, 2878 [&](Twine t) { return emitOpError(t); }); 2879 } 2880 2881 /// This is a common class used for patterns of the form 2882 /// ``` 2883 /// someop(memrefcast) -> someop 2884 /// ``` 2885 /// It folds the source of the memref.cast into the root operation directly. 2886 static LogicalResult foldMemRefCast(Operation *op) { 2887 bool folded = false; 2888 for (OpOperand &operand : op->getOpOperands()) { 2889 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 2890 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 2891 operand.set(castOp.getOperand()); 2892 folded = true; 2893 } 2894 } 2895 return success(folded); 2896 } 2897 2898 static LogicalResult foldTensorCast(Operation *op) { 2899 bool folded = false; 2900 for (OpOperand &operand : op->getOpOperands()) { 2901 auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); 2902 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { 2903 operand.set(castOp.getOperand()); 2904 folded = true; 2905 } 2906 } 2907 return success(folded); 2908 } 2909 2910 template <typename TransferOp> 2911 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { 2912 // TODO: support more aggressive createOrFold on: 2913 // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)` 2914 if (op.getShapedType().isDynamicDim(indicesIdx)) 2915 return false; 2916 Value index = op.indices()[indicesIdx]; 2917 auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>(); 2918 if (!cstOp) 2919 return false; 2920 2921 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); 2922 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); 2923 2924 return cstOp.value() + vectorSize <= sourceSize; 2925 } 2926 2927 template <typename TransferOp> 2928 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { 2929 // TODO: support 0-d corner case. 2930 // TODO: Be less conservative. 2931 if (op.getTransferRank() == 0) 2932 return failure(); 2933 AffineMap permutationMap = op.permutation_map(); 2934 bool changed = false; 2935 SmallVector<bool, 4> newInBounds; 2936 newInBounds.reserve(op.getTransferRank()); 2937 for (unsigned i = 0; i < op.getTransferRank(); ++i) { 2938 // Already marked as in-bounds, nothing to see here. 2939 if (op.isDimInBounds(i)) { 2940 newInBounds.push_back(true); 2941 continue; 2942 } 2943 // Currently out-of-bounds, check whether we can statically determine it is 2944 // inBounds. 2945 auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>(); 2946 assert(dimExpr && "Broadcast dims must be in-bounds"); 2947 auto inBounds = 2948 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition()); 2949 newInBounds.push_back(inBounds); 2950 // We commit the pattern if it is "more inbounds". 2951 changed |= inBounds; 2952 } 2953 if (!changed) 2954 return failure(); 2955 // OpBuilder is only used as a helper to build an I64ArrayAttr. 
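// Illustrative outcome (hypothetical shapes): a transfer of vector<4xf32> from
// memref<8xf32> at constant index 2 is provably in-bounds since 2 + 4 <= 8, so
// its in_bounds entry is set to true below.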
2956 OpBuilder b(op.getContext()); 2957 op->setAttr(TransferOp::getInBoundsAttrName(), 2958 b.getBoolArrayAttr(newInBounds)); 2959 return success(); 2960 } 2961 2962 /// ``` 2963 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} 2964 /// : vector<1x4xf32>, tensor<4x4xf32> 2965 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} 2966 /// : tensor<4x4xf32>, vector<1x4xf32> 2967 /// ``` 2968 /// -> Folds into 2969 /// ``` 2970 /// %v0 2971 /// ``` 2972 static Value foldRAW(TransferReadOp readOp) { 2973 if (!readOp.getShapedType().isa<RankedTensorType>()) 2974 return {}; 2975 auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>(); 2976 while (defWrite) { 2977 if (checkSameValueRAW(defWrite, readOp)) 2978 return defWrite.vector(); 2979 if (!isDisjointTransferIndices( 2980 cast<VectorTransferOpInterface>(defWrite.getOperation()), 2981 cast<VectorTransferOpInterface>(readOp.getOperation()))) 2982 break; 2983 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>(); 2984 } 2985 return {}; 2986 } 2987 2988 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) { 2989 if (Value vec = foldRAW(*this)) 2990 return vec; 2991 /// transfer_read(memrefcast) -> transfer_read 2992 if (succeeded(foldTransferInBoundsAttribute(*this))) 2993 return getResult(); 2994 if (succeeded(foldMemRefCast(*this))) 2995 return getResult(); 2996 if (succeeded(foldTensorCast(*this))) 2997 return getResult(); 2998 return OpFoldResult(); 2999 } 3000 3001 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { 3002 return llvm::to_vector<4>(getVectorType().getShape()); 3003 } 3004 3005 void TransferReadOp::getEffects( 3006 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 3007 &effects) { 3008 if (getShapedType().isa<MemRefType>()) 3009 effects.emplace_back(MemoryEffects::Read::get(), source(), 3010 SideEffects::DefaultResource::get()); 3011 } 3012 3013 namespace { 3014 /// Fold transfer_reads of a tensor.extract_slice op. E.g.: 3015 /// 3016 /// ``` 3017 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1] 3018 /// : tensor<?x?xf32> to tensor<?x?xf32> 3019 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]} 3020 /// : tensor<?x?xf32>, vector<4x5xf32> 3021 /// ``` 3022 /// is rewritten to: 3023 /// ``` 3024 /// %p0 = arith.addi %a, %e : index 3025 /// %p1 = arith.addi %b, %f : index 3026 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} 3027 /// : tensor<?x?xf32>, vector<4x5xf32> 3028 /// ``` 3029 struct FoldExtractSliceIntoTransferRead 3030 : public OpRewritePattern<TransferReadOp> { 3031 public: 3032 using OpRewritePattern<TransferReadOp>::OpRewritePattern; 3033 3034 LogicalResult matchAndRewrite(TransferReadOp xferOp, 3035 PatternRewriter &rewriter) const override { 3036 // TODO: support 0-d corner case. 3037 if (xferOp.getTransferRank() == 0) 3038 return failure(); 3039 if (xferOp.hasOutOfBoundsDim()) 3040 return failure(); 3041 if (!xferOp.permutation_map().isIdentity()) 3042 return failure(); 3043 if (xferOp.mask()) 3044 return failure(); 3045 auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>(); 3046 if (!extractOp) 3047 return failure(); 3048 if (!extractOp.hasUnitStride()) 3049 return failure(); 3050 3051 // Bail on illegal rank-reduction: we need to check that the rank-reduced 3052 // dims are exactly the leading dims. I.e. 
the following is illegal: 3053 // ``` 3054 // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] : 3055 // tensor<2x1x4xf32> to tensor<2x4xf32> 3056 // %1 = vector.transfer_read %0[0,0], %cst : 3057 // tensor<2x4xf32>, vector<2x4xf32> 3058 // ``` 3059 // 3060 // Cannot fold into: 3061 // ``` 3062 // %0 = vector.transfer_read %t[0,0,0], %cst : 3063 // tensor<2x1x4xf32>, vector<2x4xf32> 3064 // ``` 3065 // For this, check the trailing `vectorRank` dims of the extract_slice 3066 // result tensor match the trailing dims of the inferred result tensor. 3067 int64_t rankReduced = 3068 extractOp.getSourceType().getRank() - extractOp.getType().getRank(); 3069 int64_t vectorRank = xferOp.getVectorType().getRank(); 3070 RankedTensorType inferredDestTensorType = 3071 tensor::ExtractSliceOp::inferResultType( 3072 extractOp.getSourceType(), extractOp.getMixedOffsets(), 3073 extractOp.getMixedSizes(), extractOp.getMixedStrides()); 3074 auto actualDestTensorShape = extractOp.getType().getShape(); 3075 if (rankReduced > 0 && 3076 actualDestTensorShape.take_back(vectorRank) != 3077 inferredDestTensorType.getShape().take_back(vectorRank)) 3078 return failure(); 3079 3080 SmallVector<Value> newIndices; 3081 // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced 3082 // indices first. 3083 for (int64_t i = 0; i < rankReduced; ++i) { 3084 OpFoldResult offset = extractOp.getMixedOffsets()[i]; 3085 newIndices.push_back(getValueOrCreateConstantIndexOp( 3086 rewriter, extractOp.getLoc(), offset)); 3087 } 3088 for (const auto &it : llvm::enumerate(xferOp.indices())) { 3089 OpFoldResult offset = 3090 extractOp.getMixedOffsets()[it.index() + rankReduced]; 3091 newIndices.push_back(rewriter.create<arith::AddIOp>( 3092 xferOp->getLoc(), it.value(), 3093 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), 3094 offset))); 3095 } 3096 SmallVector<bool> inBounds(xferOp.getTransferRank(), true); 3097 rewriter.replaceOpWithNewOp<TransferReadOp>( 3098 xferOp, xferOp.getVectorType(), extractOp.source(), newIndices, 3099 xferOp.padding(), ArrayRef<bool>{inBounds}); 3100 3101 return success(); 3102 } 3103 }; 3104 } // namespace 3105 3106 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3107 MLIRContext *context) { 3108 results.add<FoldExtractSliceIntoTransferRead>(context); 3109 } 3110 3111 //===----------------------------------------------------------------------===// 3112 // TransferWriteOp 3113 //===----------------------------------------------------------------------===// 3114 3115 /// 1. Builder with type inference. 3116 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3117 Value vector, Value dest, ValueRange indices, 3118 AffineMapAttr permutationMapAttr, 3119 /*optional*/ Value mask, 3120 /*optional*/ ArrayAttr inBoundsAttr) { 3121 Type resultType = dest.getType().dyn_cast<RankedTensorType>(); 3122 build(builder, result, resultType, vector, dest, indices, permutationMapAttr, 3123 mask, inBoundsAttr); 3124 } 3125 3126 /// 2. Builder with type inference that sets an empty mask (variant with attrs). 3127 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3128 Value vector, Value dest, ValueRange indices, 3129 AffineMapAttr permutationMapAttr, 3130 /*optional*/ ArrayAttr inBoundsAttr) { 3131 build(builder, result, vector, dest, indices, permutationMapAttr, 3132 /*mask=*/Value(), inBoundsAttr); 3133 } 3134 3135 /// 3. 
Builder with type inference that sets an empty mask (variant without 3136 /// attrs) 3137 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3138 Value vector, Value dest, ValueRange indices, 3139 AffineMap permutationMap, 3140 Optional<ArrayRef<bool>> inBounds) { 3141 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 3142 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 3143 ? builder.getBoolArrayAttr(inBounds.getValue()) 3144 : ArrayAttr(); 3145 build(builder, result, vector, dest, indices, permutationMapAttr, 3146 /*mask=*/Value(), inBoundsAttr); 3147 } 3148 3149 /// 4. Builder with type inference that sets an empty mask and sets permutation 3150 /// map to 'getMinorIdentityMap'. 3151 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3152 Value vector, Value dest, ValueRange indices, 3153 Optional<ArrayRef<bool>> inBounds) { 3154 auto vectorType = vector.getType().cast<VectorType>(); 3155 AffineMap permutationMap = getTransferMinorIdentityMap( 3156 dest.getType().cast<ShapedType>(), vectorType); 3157 build(builder, result, vector, dest, indices, permutationMap, inBounds); 3158 } 3159 3160 ParseResult TransferWriteOp::parse(OpAsmParser &parser, 3161 OperationState &result) { 3162 auto &builder = parser.getBuilder(); 3163 SMLoc typesLoc; 3164 OpAsmParser::OperandType vectorInfo, sourceInfo; 3165 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 3166 SmallVector<Type, 2> types; 3167 OpAsmParser::OperandType maskInfo; 3168 if (parser.parseOperand(vectorInfo) || parser.parseComma() || 3169 parser.parseOperand(sourceInfo) || 3170 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) 3171 return failure(); 3172 ParseResult hasMask = parser.parseOptionalComma(); 3173 if (hasMask.succeeded() && parser.parseOperand(maskInfo)) 3174 return failure(); 3175 if (parser.parseOptionalAttrDict(result.attributes) || 3176 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 3177 return failure(); 3178 if (types.size() != 2) 3179 return parser.emitError(typesLoc, "requires two types"); 3180 auto indexType = builder.getIndexType(); 3181 VectorType vectorType = types[0].dyn_cast<VectorType>(); 3182 if (!vectorType) 3183 return parser.emitError(typesLoc, "requires vector type"); 3184 ShapedType shapedType = types[1].dyn_cast<ShapedType>(); 3185 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 3186 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 3187 auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); 3188 auto attr = result.attributes.get(permutationAttrName); 3189 if (!attr) { 3190 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 3191 result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); 3192 } 3193 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || 3194 parser.resolveOperand(sourceInfo, shapedType, result.operands) || 3195 parser.resolveOperands(indexInfo, indexType, result.operands)) 3196 return failure(); 3197 if (hasMask.succeeded()) { 3198 if (shapedType.getElementType().dyn_cast<VectorType>()) 3199 return parser.emitError( 3200 maskInfo.location, "does not support masks with vector element type"); 3201 auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type()); 3202 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 3203 return failure(); 3204 } 3205 result.addAttribute( 3206 TransferWriteOp::getOperandSegmentSizeAttr(), 3207 
builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()), 3208 static_cast<int32_t>(hasMask.succeeded())})); 3209 return failure(shapedType.isa<RankedTensorType>() && 3210 parser.addTypeToList(shapedType, result.types)); 3211 } 3212 3213 void TransferWriteOp::print(OpAsmPrinter &p) { 3214 p << " " << vector() << ", " << source() << "[" << indices() << "]"; 3215 if (mask()) 3216 p << ", " << mask(); 3217 printTransferAttrs(p, *this); 3218 p << " : " << getVectorType() << ", " << getShapedType(); 3219 } 3220 3221 LogicalResult TransferWriteOp::verify() { 3222 // Consistency of elemental types in shape and vector. 3223 ShapedType shapedType = getShapedType(); 3224 VectorType vectorType = getVectorType(); 3225 VectorType maskType = getMaskType(); 3226 auto permutationMap = permutation_map(); 3227 3228 if (llvm::size(indices()) != shapedType.getRank()) 3229 return emitOpError("requires ") << shapedType.getRank() << " indices"; 3230 3231 // We do not allow broadcast dimensions on TransferWriteOps for the moment, 3232 // as the semantics is unclear. This can be revisited later if necessary. 3233 if (hasBroadcastDim()) 3234 return emitOpError("should not have broadcast dimensions"); 3235 3236 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 3237 shapedType, vectorType, maskType, permutationMap, 3238 in_bounds() ? *in_bounds() : ArrayAttr()))) 3239 return failure(); 3240 3241 return verifyPermutationMap(permutationMap, 3242 [&](Twine t) { return emitOpError(t); }); 3243 } 3244 3245 /// Fold: 3246 /// ``` 3247 /// %t1 = ... 3248 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} : 3249 /// tensor<static_sizesxf32>, vector<static_sizesxf32> 3250 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} : 3251 /// vector<static_sizesxf32>, tensor<static_sizesxf32> 3252 /// ``` 3253 /// 3254 /// into: 3255 /// 3256 /// ``` 3257 /// %t0 3258 /// ``` 3259 /// 3260 /// The producer of t1 may or may not be DCE'd depending on whether it is a 3261 /// block argument or has side effects. 3262 static LogicalResult foldReadInitWrite(TransferWriteOp write, 3263 ArrayRef<Attribute>, 3264 SmallVectorImpl<OpFoldResult> &results) { 3265 // TODO: support 0-d corner case. 3266 if (write.getTransferRank() == 0) 3267 return failure(); 3268 auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>(); 3269 // If not operating on tensors, bail. 3270 if (!rankedTensorType) 3271 return failure(); 3272 // If no read, bail. 3273 auto read = write.vector().getDefiningOp<vector::TransferReadOp>(); 3274 if (!read) 3275 return failure(); 3276 // TODO: support 0-d corner case. 3277 if (read.getTransferRank() == 0) 3278 return failure(); 3279 // For now, only accept minor identity. Future: composition is minor identity. 3280 if (!read.permutation_map().isMinorIdentity() || 3281 !write.permutation_map().isMinorIdentity()) 3282 return failure(); 3283 // Bail on mismatching ranks. 3284 if (read.getTransferRank() != write.getTransferRank()) 3285 return failure(); 3286 // Bail on potential out-of-bounds accesses. 3287 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim()) 3288 return failure(); 3289 // Tensor types must be the same. 3290 if (read.source().getType() != rankedTensorType) 3291 return failure(); 3292 // Vector types must be the same. 3293 if (read.getVectorType() != write.getVectorType()) 3294 return failure(); 3295 // Vector and Tensor shapes must match. 
3296 if (read.getVectorType().getShape() != rankedTensorType.getShape())
3297 return failure();
3298 // If any index is nonzero.
3299 auto isNotConstantZero = [](Value v) {
3300 auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3301 return !cstOp || cstOp.value() != 0;
3302 };
3303 if (llvm::any_of(read.indices(), isNotConstantZero) ||
3304 llvm::any_of(write.indices(), isNotConstantZero))
3305 return failure();
3306 // Success.
3307 results.push_back(read.source());
3308 return success();
3309 }
3310 
3311 static bool checkSameValueWAR(vector::TransferReadOp read,
3312 vector::TransferWriteOp write) {
3313 return read.source() == write.source() && read.indices() == write.indices() &&
3314 read.permutation_map() == write.permutation_map() &&
3315 read.getVectorType() == write.getVectorType() && !read.mask() &&
3316 !write.mask();
3317 }
3318 /// Fold transfer_write write after read:
3319 /// ```
3320 /// %t0 = ...
3321 /// %v = vector.transfer_read %t0[%c0...] :
3322 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3323 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
3324 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3325 /// ```
3326 ///
3327 /// into:
3328 ///
3329 /// ```
3330 /// %t0
3331 /// ```
3332 static LogicalResult foldWAR(TransferWriteOp write,
3333 SmallVectorImpl<OpFoldResult> &results) {
3334 if (!write.source().getType().isa<RankedTensorType>())
3335 return failure();
3336 auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3337 if (!read)
3338 return failure();
3339 
3340 if (!checkSameValueWAR(read, write))
3341 return failure();
3342 results.push_back(read.source());
3343 return success();
3344 }
3345 
3346 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3347 SmallVectorImpl<OpFoldResult> &results) {
3348 if (succeeded(foldReadInitWrite(*this, operands, results)))
3349 return success();
3350 if (succeeded(foldWAR(*this, results)))
3351 return success();
3352 if (succeeded(foldTransferInBoundsAttribute(*this)))
3353 return success();
3354 return foldMemRefCast(*this);
3355 }
3356 
3357 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3358 return llvm::to_vector<4>(getVectorType().getShape());
3359 }
3360 
3361 void TransferWriteOp::getEffects(
3362 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3363 &effects) {
3364 if (getShapedType().isa<MemRefType>())
3365 effects.emplace_back(MemoryEffects::Write::get(), source(),
3366 SideEffects::DefaultResource::get());
3367 }
3368 
3369 namespace {
3370 /// Remove a dead transfer write from the SSA chain so that it can be
3371 /// eliminated by DCE.
3372 /// ```
3373 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3374 /// : vector<1x4xf32>, tensor<4x4xf32>
3375 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3376 /// : vector<1x4xf32>, tensor<4x4xf32>
3377 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3378 /// : vector<1x4xf32>, tensor<4x4xf32>
3379 /// ```
3380 ///
3381 /// into:
3382 ///
3383 /// ```
3384 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3385 /// : vector<1x4xf32>, tensor<4x4xf32>
3386 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3387 /// : vector<1x4xf32>, tensor<4x4xf32>
3388 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3389 /// : vector<1x4xf32>, tensor<4x4xf32>
3390 /// ```
3391 ///
3392 /// The `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't
3393 /// have any other uses.
3394 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3395 public:
3396 using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
3397 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3398 PatternRewriter &rewriter) const override {
3399 if (!writeOp.getShapedType().isa<RankedTensorType>())
3400 return failure();
3401 vector::TransferWriteOp writeToModify = writeOp;
3402 
3403 auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
3404 while (defWrite) {
3405 if (checkSameValueWAW(writeOp, defWrite)) {
3406 writeToModify.sourceMutable().assign(defWrite.source());
3407 return success();
3408 }
3409 if (!isDisjointTransferIndices(
3410 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3411 cast<VectorTransferOpInterface>(writeOp.getOperation())))
3412 break;
3413 // If the previous write op doesn't have any other use, we can safely look
3414 // at the previous store to see if it can be removed.
3415 if (!defWrite->hasOneUse())
3416 break;
3417 writeToModify = defWrite;
3418 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
3419 }
3420 return failure();
3421 }
3422 };
3423 
3424 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3425 /// could directly write to the insert_slice's destination. E.g.:
3426 ///
3427 /// ```
3428 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3429 /// : vector<4x5xf32>, tensor<4x5xf32>
3430 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3431 /// : tensor<4x5xf32> into tensor<?x?xf32>
3432 /// ```
3433 /// is rewritten to:
3434 /// ```
3435 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3436 /// : vector<4x5xf32>, tensor<?x?xf32>
3437 /// ```
3438 struct FoldInsertSliceIntoTransferWrite
3439 : public OpRewritePattern<tensor::InsertSliceOp> {
3440 public:
3441 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3442 
3443 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3444 PatternRewriter &rewriter) const override {
3445 if (!insertOp.hasUnitStride())
3446 return failure();
3447 
3448 auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
3449 if (!xferOp)
3450 return failure();
3451 // TODO: support 0-d corner case.
3452 if (xferOp.getTransferRank() == 0)
3453 return failure();
3454 
3455 if (xferOp.hasOutOfBoundsDim())
3456 return failure();
3457 if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3458 return failure();
3459 if (xferOp.mask())
3460 return failure();
3461 // Fold only if the TransferWriteOp completely overwrites the `source` with
3462 // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3463 // content is the data of the vector.
3464 if (!llvm::equal(xferOp.getVectorType().getShape(),
3465 xferOp.getShapedType().getShape()))
3466 return failure();
3467 if (!xferOp.permutation_map().isIdentity())
3468 return failure();
3469 
3470 // Bail on illegal rank-reduction: we need to check that the rank-reduced
3471 // dims are exactly the leading dims. I.e.
the following is illegal: 3472 // ``` 3473 // %0 = vector.transfer_write %v, %t[0,0], %cst : 3474 // vector<2x4xf32>, tensor<2x4xf32> 3475 // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] : 3476 // tensor<2x4xf32> into tensor<2x1x4xf32> 3477 // ``` 3478 // 3479 // Cannot fold into: 3480 // ``` 3481 // %0 = vector.transfer_write %v, %t[0,0,0], %cst : 3482 // vector<2x4xf32>, tensor<2x1x4xf32> 3483 // ``` 3484 // For this, check the trailing `vectorRank` dims of the insert_slice result 3485 // tensor match the trailing dims of the inferred result tensor. 3486 int64_t rankReduced = 3487 insertOp.getType().getRank() - insertOp.getSourceType().getRank(); 3488 int64_t vectorRank = xferOp.getVectorType().getRank(); 3489 RankedTensorType inferredSourceTensorType = 3490 tensor::ExtractSliceOp::inferResultType( 3491 insertOp.getType(), insertOp.getMixedOffsets(), 3492 insertOp.getMixedSizes(), insertOp.getMixedStrides()); 3493 auto actualSourceTensorShape = insertOp.getSourceType().getShape(); 3494 if (rankReduced > 0 && 3495 actualSourceTensorShape.take_back(vectorRank) != 3496 inferredSourceTensorType.getShape().take_back(vectorRank)) 3497 return failure(); 3498 3499 SmallVector<Value> indices = getValueOrCreateConstantIndexOp( 3500 rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); 3501 SmallVector<bool> inBounds(xferOp.getTransferRank(), true); 3502 rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(), 3503 insertOp.dest(), indices, 3504 ArrayRef<bool>{inBounds}); 3505 return success(); 3506 } 3507 }; 3508 } // namespace 3509 3510 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, 3511 MLIRContext *context) { 3512 results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context); 3513 } 3514 3515 //===----------------------------------------------------------------------===// 3516 // LoadOp 3517 //===----------------------------------------------------------------------===// 3518 3519 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, 3520 MemRefType memRefTy) { 3521 if (!isLastMemrefDimUnitStride(memRefTy)) 3522 return op->emitOpError("most minor memref dim must have unit stride"); 3523 return success(); 3524 } 3525 3526 LogicalResult vector::LoadOp::verify() { 3527 VectorType resVecTy = getVectorType(); 3528 MemRefType memRefTy = getMemRefType(); 3529 3530 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) 3531 return failure(); 3532 3533 // Checks for vector memrefs. 
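// For example (illustrative), a load from memref<100xvector<8xf32>> must
// produce exactly vector<8xf32>; any other result vector type is rejected.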
3534 Type memElemTy = memRefTy.getElementType(); 3535 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3536 if (memVecTy != resVecTy) 3537 return emitOpError("base memref and result vector types should match"); 3538 memElemTy = memVecTy.getElementType(); 3539 } 3540 3541 if (resVecTy.getElementType() != memElemTy) 3542 return emitOpError("base and result element types should match"); 3543 if (llvm::size(indices()) != memRefTy.getRank()) 3544 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 3545 return success(); 3546 } 3547 3548 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) { 3549 if (succeeded(foldMemRefCast(*this))) 3550 return getResult(); 3551 return OpFoldResult(); 3552 } 3553 3554 //===----------------------------------------------------------------------===// 3555 // StoreOp 3556 //===----------------------------------------------------------------------===// 3557 3558 LogicalResult vector::StoreOp::verify() { 3559 VectorType valueVecTy = getVectorType(); 3560 MemRefType memRefTy = getMemRefType(); 3561 3562 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) 3563 return failure(); 3564 3565 // Checks for vector memrefs. 3566 Type memElemTy = memRefTy.getElementType(); 3567 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3568 if (memVecTy != valueVecTy) 3569 return emitOpError( 3570 "base memref and valueToStore vector types should match"); 3571 memElemTy = memVecTy.getElementType(); 3572 } 3573 3574 if (valueVecTy.getElementType() != memElemTy) 3575 return emitOpError("base and valueToStore element type should match"); 3576 if (llvm::size(indices()) != memRefTy.getRank()) 3577 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 3578 return success(); 3579 } 3580 3581 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands, 3582 SmallVectorImpl<OpFoldResult> &results) { 3583 return foldMemRefCast(*this); 3584 } 3585 3586 //===----------------------------------------------------------------------===// 3587 // MaskedLoadOp 3588 //===----------------------------------------------------------------------===// 3589 3590 LogicalResult MaskedLoadOp::verify() { 3591 VectorType maskVType = getMaskVectorType(); 3592 VectorType passVType = getPassThruVectorType(); 3593 VectorType resVType = getVectorType(); 3594 MemRefType memType = getMemRefType(); 3595 3596 if (resVType.getElementType() != memType.getElementType()) 3597 return emitOpError("base and result element type should match"); 3598 if (llvm::size(indices()) != memType.getRank()) 3599 return emitOpError("requires ") << memType.getRank() << " indices"; 3600 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3601 return emitOpError("expected result dim to match mask dim"); 3602 if (resVType != passVType) 3603 return emitOpError("expected pass_thru of same type as result type"); 3604 return success(); 3605 } 3606 3607 namespace { 3608 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { 3609 public: 3610 using OpRewritePattern<MaskedLoadOp>::OpRewritePattern; 3611 LogicalResult matchAndRewrite(MaskedLoadOp load, 3612 PatternRewriter &rewriter) const override { 3613 switch (get1DMaskFormat(load.mask())) { 3614 case MaskFormat::AllTrue: 3615 rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(), 3616 load.base(), load.indices()); 3617 return success(); 3618 case MaskFormat::AllFalse: 3619 rewriter.replaceOp(load, load.pass_thru()); 3620 return success(); 3621 case MaskFormat::Unknown: 3622 return failure(); 3623 } 3624 llvm_unreachable("Unexpected 
1DMaskFormat on MaskedLoad"); 3625 } 3626 }; 3627 } // namespace 3628 3629 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3630 MLIRContext *context) { 3631 results.add<MaskedLoadFolder>(context); 3632 } 3633 3634 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) { 3635 if (succeeded(foldMemRefCast(*this))) 3636 return getResult(); 3637 return OpFoldResult(); 3638 } 3639 3640 //===----------------------------------------------------------------------===// 3641 // MaskedStoreOp 3642 //===----------------------------------------------------------------------===// 3643 3644 LogicalResult MaskedStoreOp::verify() { 3645 VectorType maskVType = getMaskVectorType(); 3646 VectorType valueVType = getVectorType(); 3647 MemRefType memType = getMemRefType(); 3648 3649 if (valueVType.getElementType() != memType.getElementType()) 3650 return emitOpError("base and valueToStore element type should match"); 3651 if (llvm::size(indices()) != memType.getRank()) 3652 return emitOpError("requires ") << memType.getRank() << " indices"; 3653 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3654 return emitOpError("expected valueToStore dim to match mask dim"); 3655 return success(); 3656 } 3657 3658 namespace { 3659 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { 3660 public: 3661 using OpRewritePattern<MaskedStoreOp>::OpRewritePattern; 3662 LogicalResult matchAndRewrite(MaskedStoreOp store, 3663 PatternRewriter &rewriter) const override { 3664 switch (get1DMaskFormat(store.mask())) { 3665 case MaskFormat::AllTrue: 3666 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3667 store, store.valueToStore(), store.base(), store.indices()); 3668 return success(); 3669 case MaskFormat::AllFalse: 3670 rewriter.eraseOp(store); 3671 return success(); 3672 case MaskFormat::Unknown: 3673 return failure(); 3674 } 3675 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore"); 3676 } 3677 }; 3678 } // namespace 3679 3680 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3681 MLIRContext *context) { 3682 results.add<MaskedStoreFolder>(context); 3683 } 3684 3685 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands, 3686 SmallVectorImpl<OpFoldResult> &results) { 3687 return foldMemRefCast(*this); 3688 } 3689 3690 //===----------------------------------------------------------------------===// 3691 // GatherOp 3692 //===----------------------------------------------------------------------===// 3693 3694 LogicalResult GatherOp::verify() { 3695 VectorType indVType = getIndexVectorType(); 3696 VectorType maskVType = getMaskVectorType(); 3697 VectorType resVType = getVectorType(); 3698 MemRefType memType = getMemRefType(); 3699 3700 if (resVType.getElementType() != memType.getElementType()) 3701 return emitOpError("base and result element type should match"); 3702 if (llvm::size(indices()) != memType.getRank()) 3703 return emitOpError("requires ") << memType.getRank() << " indices"; 3704 if (resVType.getDimSize(0) != indVType.getDimSize(0)) 3705 return emitOpError("expected result dim to match indices dim"); 3706 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3707 return emitOpError("expected result dim to match mask dim"); 3708 if (resVType != getPassThruVectorType()) 3709 return emitOpError("expected pass_thru of same type as result type"); 3710 return success(); 3711 } 3712 3713 namespace { 3714 class GatherFolder final : public OpRewritePattern<GatherOp> { 3715 public: 3716 using OpRewritePattern<GatherOp>::OpRewritePattern; 
3717 LogicalResult matchAndRewrite(GatherOp gather, 3718 PatternRewriter &rewriter) const override { 3719 switch (get1DMaskFormat(gather.mask())) { 3720 case MaskFormat::AllTrue: 3721 return failure(); // no unmasked equivalent 3722 case MaskFormat::AllFalse: 3723 rewriter.replaceOp(gather, gather.pass_thru()); 3724 return success(); 3725 case MaskFormat::Unknown: 3726 return failure(); 3727 } 3728 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder"); 3729 } 3730 }; 3731 } // namespace 3732 3733 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, 3734 MLIRContext *context) { 3735 results.add<GatherFolder>(context); 3736 } 3737 3738 //===----------------------------------------------------------------------===// 3739 // ScatterOp 3740 //===----------------------------------------------------------------------===// 3741 3742 LogicalResult ScatterOp::verify() { 3743 VectorType indVType = getIndexVectorType(); 3744 VectorType maskVType = getMaskVectorType(); 3745 VectorType valueVType = getVectorType(); 3746 MemRefType memType = getMemRefType(); 3747 3748 if (valueVType.getElementType() != memType.getElementType()) 3749 return emitOpError("base and valueToStore element type should match"); 3750 if (llvm::size(indices()) != memType.getRank()) 3751 return emitOpError("requires ") << memType.getRank() << " indices"; 3752 if (valueVType.getDimSize(0) != indVType.getDimSize(0)) 3753 return emitOpError("expected valueToStore dim to match indices dim"); 3754 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3755 return emitOpError("expected valueToStore dim to match mask dim"); 3756 return success(); 3757 } 3758 3759 namespace { 3760 class ScatterFolder final : public OpRewritePattern<ScatterOp> { 3761 public: 3762 using OpRewritePattern<ScatterOp>::OpRewritePattern; 3763 LogicalResult matchAndRewrite(ScatterOp scatter, 3764 PatternRewriter &rewriter) const override { 3765 switch (get1DMaskFormat(scatter.mask())) { 3766 case MaskFormat::AllTrue: 3767 return failure(); // no unmasked equivalent 3768 case MaskFormat::AllFalse: 3769 rewriter.eraseOp(scatter); 3770 return success(); 3771 case MaskFormat::Unknown: 3772 return failure(); 3773 } 3774 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder"); 3775 } 3776 }; 3777 } // namespace 3778 3779 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results, 3780 MLIRContext *context) { 3781 results.add<ScatterFolder>(context); 3782 } 3783 3784 //===----------------------------------------------------------------------===// 3785 // ExpandLoadOp 3786 //===----------------------------------------------------------------------===// 3787 3788 LogicalResult ExpandLoadOp::verify() { 3789 VectorType maskVType = getMaskVectorType(); 3790 VectorType passVType = getPassThruVectorType(); 3791 VectorType resVType = getVectorType(); 3792 MemRefType memType = getMemRefType(); 3793 3794 if (resVType.getElementType() != memType.getElementType()) 3795 return emitOpError("base and result element type should match"); 3796 if (llvm::size(indices()) != memType.getRank()) 3797 return emitOpError("requires ") << memType.getRank() << " indices"; 3798 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3799 return emitOpError("expected result dim to match mask dim"); 3800 if (resVType != passVType) 3801 return emitOpError("expected pass_thru of same type as result type"); 3802 return success(); 3803 } 3804 3805 namespace { 3806 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { 3807 public: 3808 using 
OpRewritePattern<ExpandLoadOp>::OpRewritePattern; 3809 LogicalResult matchAndRewrite(ExpandLoadOp expand, 3810 PatternRewriter &rewriter) const override { 3811 switch (get1DMaskFormat(expand.mask())) { 3812 case MaskFormat::AllTrue: 3813 rewriter.replaceOpWithNewOp<vector::LoadOp>( 3814 expand, expand.getType(), expand.base(), expand.indices()); 3815 return success(); 3816 case MaskFormat::AllFalse: 3817 rewriter.replaceOp(expand, expand.pass_thru()); 3818 return success(); 3819 case MaskFormat::Unknown: 3820 return failure(); 3821 } 3822 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder"); 3823 } 3824 }; 3825 } // namespace 3826 3827 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3828 MLIRContext *context) { 3829 results.add<ExpandLoadFolder>(context); 3830 } 3831 3832 //===----------------------------------------------------------------------===// 3833 // CompressStoreOp 3834 //===----------------------------------------------------------------------===// 3835 3836 LogicalResult CompressStoreOp::verify() { 3837 VectorType maskVType = getMaskVectorType(); 3838 VectorType valueVType = getVectorType(); 3839 MemRefType memType = getMemRefType(); 3840 3841 if (valueVType.getElementType() != memType.getElementType()) 3842 return emitOpError("base and valueToStore element type should match"); 3843 if (llvm::size(indices()) != memType.getRank()) 3844 return emitOpError("requires ") << memType.getRank() << " indices"; 3845 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3846 return emitOpError("expected valueToStore dim to match mask dim"); 3847 return success(); 3848 } 3849 3850 namespace { 3851 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { 3852 public: 3853 using OpRewritePattern<CompressStoreOp>::OpRewritePattern; 3854 LogicalResult matchAndRewrite(CompressStoreOp compress, 3855 PatternRewriter &rewriter) const override { 3856 switch (get1DMaskFormat(compress.mask())) { 3857 case MaskFormat::AllTrue: 3858 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3859 compress, compress.valueToStore(), compress.base(), 3860 compress.indices()); 3861 return success(); 3862 case MaskFormat::AllFalse: 3863 rewriter.eraseOp(compress); 3864 return success(); 3865 case MaskFormat::Unknown: 3866 return failure(); 3867 } 3868 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder"); 3869 } 3870 }; 3871 } // namespace 3872 3873 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3874 MLIRContext *context) { 3875 results.add<CompressStoreFolder>(context); 3876 } 3877 3878 //===----------------------------------------------------------------------===// 3879 // ShapeCastOp 3880 //===----------------------------------------------------------------------===// 3881 3882 /// Returns true if each element of 'a' is equal to the product of a contiguous 3883 /// sequence of the elements of 'b'. Returns false otherwise. 3884 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 3885 unsigned rankA = a.size(); 3886 unsigned rankB = b.size(); 3887 assert(rankA < rankB); 3888 3889 unsigned i = 0; 3890 unsigned j = 0; 3891 while (i < rankA && j < rankB) { 3892 int64_t dimA = a[i]; 3893 int64_t dimB = 1; 3894 while (dimB < dimA && j < rankB) 3895 dimB *= b[j++]; 3896 if (dimA != dimB) 3897 break; 3898 ++i; 3899 3900 // Handle the case when trailing dimensions are of size 1. 3901 // Include them into the contiguous sequence. 
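// Illustrative case (hypothetical shapes): a = [4, 8] and b = [4, 2, 4, 1] is a
// valid cast, since 8 == 2 * 4 and the trailing unit dimension of b is absorbed
// into that contiguous sequence.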
3902 auto isOne = [](int64_t v) { return v == 1; }; 3903 if (i < rankA && llvm::all_of(a.slice(i), isOne)) 3904 i = rankA; 3905 if (j < rankB && llvm::all_of(b.slice(j), isOne)) 3906 j = rankB; 3907 } 3908 3909 return i == rankA && j == rankB; 3910 } 3911 3912 static LogicalResult verifyVectorShapeCast(Operation *op, 3913 VectorType sourceVectorType, 3914 VectorType resultVectorType) { 3915 // Check that element type is the same. 3916 if (sourceVectorType.getElementType() != resultVectorType.getElementType()) 3917 return op->emitOpError("source/result vectors must have same element type"); 3918 auto sourceShape = sourceVectorType.getShape(); 3919 auto resultShape = resultVectorType.getShape(); 3920 3921 // Check that product of source dim sizes matches product of result dim sizes. 3922 int64_t sourceDimProduct = std::accumulate( 3923 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{}); 3924 int64_t resultDimProduct = std::accumulate( 3925 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{}); 3926 if (sourceDimProduct != resultDimProduct) 3927 return op->emitOpError("source/result number of elements must match"); 3928 3929 // Check that expanding/contracting rank cases. 3930 unsigned sourceRank = sourceVectorType.getRank(); 3931 unsigned resultRank = resultVectorType.getRank(); 3932 if (sourceRank < resultRank) { 3933 if (!isValidShapeCast(sourceShape, resultShape)) 3934 return op->emitOpError("invalid shape cast"); 3935 } else if (sourceRank > resultRank) { 3936 if (!isValidShapeCast(resultShape, sourceShape)) 3937 return op->emitOpError("invalid shape cast"); 3938 } 3939 return success(); 3940 } 3941 3942 LogicalResult ShapeCastOp::verify() { 3943 auto sourceVectorType = source().getType().dyn_cast_or_null<VectorType>(); 3944 auto resultVectorType = result().getType().dyn_cast_or_null<VectorType>(); 3945 3946 // Check if source/result are of vector type. 3947 if (sourceVectorType && resultVectorType) 3948 return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType); 3949 3950 return success(); 3951 } 3952 3953 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) { 3954 // Nop shape cast. 3955 if (source().getType() == result().getType()) 3956 return source(); 3957 3958 // Canceling shape casts. 3959 if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) { 3960 if (result().getType() == otherOp.source().getType()) 3961 return otherOp.source(); 3962 3963 // Only allows valid transitive folding. 3964 VectorType srcType = otherOp.source().getType().cast<VectorType>(); 3965 VectorType resultType = getResult().getType().cast<VectorType>(); 3966 if (srcType.getRank() < resultType.getRank()) { 3967 if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) 3968 return {}; 3969 } else if (srcType.getRank() > resultType.getRank()) { 3970 if (!isValidShapeCast(resultType.getShape(), srcType.getShape())) 3971 return {}; 3972 } else { 3973 return {}; 3974 } 3975 3976 setOperand(otherOp.source()); 3977 return getResult(); 3978 } 3979 return {}; 3980 } 3981 3982 namespace { 3983 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. 
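// For example (illustrative):
//
// ```
// %cst = arith.constant dense<1.0> : vector<2x4xf32>
// %0 = vector.shape_cast %cst : vector<2x4xf32> to vector<8xf32>
// ```
// becomes
// ```
// %0 = arith.constant dense<1.0> : vector<8xf32>
// ```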
3984 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { 3985 public: 3986 using OpRewritePattern<ShapeCastOp>::OpRewritePattern; 3987 3988 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, 3989 PatternRewriter &rewriter) const override { 3990 auto constantOp = shapeCastOp.source().getDefiningOp<arith::ConstantOp>(); 3991 if (!constantOp) 3992 return failure(); 3993 // Only handle splat for now. 3994 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 3995 if (!dense) 3996 return failure(); 3997 auto newAttr = 3998 DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(), 3999 dense.getSplatValue<Attribute>()); 4000 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr); 4001 return success(); 4002 } 4003 }; 4004 4005 } // namespace 4006 4007 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 4008 MLIRContext *context) { 4009 // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp. 4010 results.add<ShapeCastConstantFolder>(context); 4011 } 4012 4013 //===----------------------------------------------------------------------===// 4014 // VectorBitCastOp 4015 //===----------------------------------------------------------------------===// 4016 4017 LogicalResult BitCastOp::verify() { 4018 auto sourceVectorType = getSourceVectorType(); 4019 auto resultVectorType = getResultVectorType(); 4020 4021 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { 4022 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) 4023 return emitOpError("dimension size mismatch at: ") << i; 4024 } 4025 4026 DataLayout dataLayout = DataLayout::closest(*this); 4027 auto sourceElementBits = 4028 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); 4029 auto resultElementBits = 4030 dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); 4031 4032 if (sourceVectorType.getRank() == 0) { 4033 if (sourceElementBits != resultElementBits) 4034 return emitOpError("source/result bitwidth of the 0-D vector element " 4035 "types must be equal"); 4036 } else if (sourceElementBits * sourceVectorType.getShape().back() != 4037 resultElementBits * resultVectorType.getShape().back()) { 4038 return emitOpError( 4039 "source/result bitwidth of the minor 1-D vectors must be equal"); 4040 } 4041 4042 return success(); 4043 } 4044 4045 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) { 4046 // Nop cast. 4047 if (source().getType() == result().getType()) 4048 return source(); 4049 4050 // Canceling bitcasts. 4051 if (auto otherOp = source().getDefiningOp<BitCastOp>()) 4052 if (result().getType() == otherOp.source().getType()) 4053 return otherOp.source(); 4054 4055 Attribute sourceConstant = operands.front(); 4056 if (!sourceConstant) 4057 return {}; 4058 4059 Type srcElemType = getSourceVectorType().getElementType(); 4060 Type dstElemType = getResultVectorType().getElementType(); 4061 4062 if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) { 4063 if (floatPack.isSplat()) { 4064 auto splat = floatPack.getSplatValue<FloatAttr>(); 4065 4066 // Casting fp16 into fp32. 4067 if (srcElemType.isF16() && dstElemType.isF32()) { 4068 uint32_t bits = static_cast<uint32_t>( 4069 splat.getValue().bitcastToAPInt().getZExtValue()); 4070 // Duplicate the 16-bit pattern. 
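// For example (illustrative), an f16 splat of 1.0 (bit pattern 0x3C00) becomes
// the 32-bit word 0x3C003C00, which is then reinterpreted as an f32 splat.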
4071 bits = (bits << 16) | (bits & 0xffff); 4072 APInt intBits(32, bits); 4073 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits); 4074 return DenseElementsAttr::get(getResultVectorType(), floatBits); 4075 } 4076 } 4077 } 4078 4079 return {}; 4080 } 4081 4082 //===----------------------------------------------------------------------===// 4083 // TypeCastOp 4084 //===----------------------------------------------------------------------===// 4085 4086 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { 4087 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>(); 4088 SmallVector<int64_t, 8> res(memRefType.getShape().begin(), 4089 memRefType.getShape().end()); 4090 if (vectorType) 4091 res.append(vectorType.getShape().begin(), vectorType.getShape().end()); 4092 return res; 4093 } 4094 4095 /// Build the canonical memRefType with a single vector. 4096 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>. 4097 void TypeCastOp::build(OpBuilder &builder, OperationState &result, 4098 Value source) { 4099 result.addOperands(source); 4100 MemRefType memRefType = source.getType().cast<MemRefType>(); 4101 VectorType vectorType = 4102 VectorType::get(extractShape(memRefType), 4103 getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); 4104 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), 4105 memRefType.getMemorySpace())); 4106 } 4107 4108 LogicalResult TypeCastOp::verify() { 4109 MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType()); 4110 if (!canonicalType.getLayout().isIdentity()) 4111 return emitOpError("expects operand to be a memref with identity layout"); 4112 if (!getResultMemRefType().getLayout().isIdentity()) 4113 return emitOpError("expects result to be a memref with identity layout"); 4114 if (getResultMemRefType().getMemorySpace() != 4115 getMemRefType().getMemorySpace()) 4116 return emitOpError("expects result in same memory space"); 4117 4118 auto sourceType = getMemRefType(); 4119 auto resultType = getResultMemRefType(); 4120 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != 4121 getElementTypeOrSelf(getElementTypeOrSelf(resultType))) 4122 return emitOpError( 4123 "expects result and operand with same underlying scalar type: ") 4124 << resultType; 4125 if (extractShape(sourceType) != extractShape(resultType)) 4126 return emitOpError( 4127 "expects concatenated result and operand shapes to be equal: ") 4128 << resultType; 4129 return success(); 4130 } 4131 4132 //===----------------------------------------------------------------------===// 4133 // TransposeOp 4134 //===----------------------------------------------------------------------===// 4135 4136 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, 4137 Value vector, ArrayRef<int64_t> transp) { 4138 VectorType vt = vector.getType().cast<VectorType>(); 4139 SmallVector<int64_t, 4> transposedShape(vt.getRank()); 4140 for (unsigned i = 0; i < transp.size(); ++i) 4141 transposedShape[i] = vt.getShape()[transp[i]]; 4142 4143 result.addOperands(vector); 4144 result.addTypes(VectorType::get(transposedShape, vt.getElementType())); 4145 result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp)); 4146 } 4147 4148 // Eliminates transpose operations, which produce values identical to their 4149 // input values. This happens when the dimensions of the input vector remain in 4150 // their original order after the transpose operation. 
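// For example (illustrative), the identity permutation folds away:
//
// ```
// %0 = vector.transpose %v, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
// ```
// is replaced by %v.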
4151 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) { 4152 SmallVector<int64_t, 4> transp; 4153 getTransp(transp); 4154 4155 // Check if the permutation of the dimensions contains sequential values: 4156 // {0, 1, 2, ...}. 4157 for (int64_t i = 0, e = transp.size(); i < e; i++) { 4158 if (transp[i] != i) 4159 return {}; 4160 } 4161 4162 return vector(); 4163 } 4164 4165 LogicalResult vector::TransposeOp::verify() { 4166 VectorType vectorType = getVectorType(); 4167 VectorType resultType = getResultType(); 4168 int64_t rank = resultType.getRank(); 4169 if (vectorType.getRank() != rank) 4170 return emitOpError("vector result rank mismatch: ") << rank; 4171 // Verify transposition array. 4172 auto transpAttr = transp().getValue(); 4173 int64_t size = transpAttr.size(); 4174 if (rank != size) 4175 return emitOpError("transposition length mismatch: ") << size; 4176 SmallVector<bool, 8> seen(rank, false); 4177 for (const auto &ta : llvm::enumerate(transpAttr)) { 4178 int64_t i = ta.value().cast<IntegerAttr>().getInt(); 4179 if (i < 0 || i >= rank) 4180 return emitOpError("transposition index out of range: ") << i; 4181 if (seen[i]) 4182 return emitOpError("duplicate position index: ") << i; 4183 seen[i] = true; 4184 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) 4185 return emitOpError("dimension size mismatch at: ") << i; 4186 } 4187 return success(); 4188 } 4189 4190 namespace { 4191 4192 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. 4193 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { 4194 public: 4195 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 4196 4197 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 4198 PatternRewriter &rewriter) const override { 4199 // Wrapper around vector::TransposeOp::getTransp() for cleaner code. 4200 auto getPermutation = [](vector::TransposeOp transpose) { 4201 SmallVector<int64_t, 4> permutation; 4202 transpose.getTransp(permutation); 4203 return permutation; 4204 }; 4205 4206 // Composes two permutations: result[i] = permutation1[permutation2[i]]. 4207 auto composePermutations = [](ArrayRef<int64_t> permutation1, 4208 ArrayRef<int64_t> permutation2) { 4209 SmallVector<int64_t, 4> result; 4210 for (auto index : permutation2) 4211 result.push_back(permutation1[index]); 4212 return result; 4213 }; 4214 4215 // Return if the input of 'transposeOp' is not defined by another transpose. 4216 vector::TransposeOp parentTransposeOp = 4217 transposeOp.vector().getDefiningOp<vector::TransposeOp>(); 4218 if (!parentTransposeOp) 4219 return failure(); 4220 4221 SmallVector<int64_t, 4> permutation = composePermutations( 4222 getPermutation(parentTransposeOp), getPermutation(transposeOp)); 4223 // Replace 'transposeOp' with a new transpose operation. 
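// For example (illustrative), a transpose with permutation [1, 0] feeding
// another [1, 0] transpose composes to the identity permutation [0, 1], which
// TransposeOp::fold above can then remove entirely.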
4224 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
4225 transposeOp, transposeOp.getResult().getType(),
4226 parentTransposeOp.vector(),
4227 vector::getVectorSubscriptAttr(rewriter, permutation));
4228 return success();
4229 }
4230 };
4231 
4232 } // namespace
4233 
4234 void vector::TransposeOp::getCanonicalizationPatterns(
4235 RewritePatternSet &results, MLIRContext *context) {
4236 results.add<TransposeFolder>(context);
4237 }
4238 
4239 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
4240 populateFromInt64AttrArray(transp(), results);
4241 }
4242 
4243 //===----------------------------------------------------------------------===//
4244 // ConstantMaskOp
4245 //===----------------------------------------------------------------------===//
4246 
4247 LogicalResult ConstantMaskOp::verify() {
4248 auto resultType = getResult().getType().cast<VectorType>();
4249 // Check the corner case of 0-D vectors first.
4250 if (resultType.getRank() == 0) {
4251 if (mask_dim_sizes().size() != 1)
4252 return emitError("array attr must have length 1 for 0-D vectors");
4253 auto dim = mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
4254 if (dim != 0 && dim != 1)
4255 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
4256 return success();
4257 }
4258 
4259 // Verify that array attr size matches the rank of the vector result.
4260 if (static_cast<int64_t>(mask_dim_sizes().size()) != resultType.getRank())
4261 return emitOpError(
4262 "must specify array attr of size equal vector result rank");
4263 // Verify that each array attr element is in bounds of corresponding vector
4264 // result dimension size.
4265 auto resultShape = resultType.getShape();
4266 SmallVector<int64_t, 4> maskDimSizes;
4267 for (const auto &it : llvm::enumerate(mask_dim_sizes())) {
4268 int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
4269 if (attrValue < 0 || attrValue > resultShape[it.index()])
4270 return emitOpError(
4271 "array attr of size out of bounds of vector result dimension size");
4272 maskDimSizes.push_back(attrValue);
4273 }
4274 // Verify that if one mask dim size is zero, they all should be zero (because
4275 // the mask region is a conjunction of each mask dimension interval).
4276 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
4277 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
4278 if (anyZeros && !allZeros)
4279 return emitOpError("expected all mask dim sizes to be zeros, "
4280 "as a result of conjunction with zero mask dim");
4281 return success();
4282 }
4283 
4284 //===----------------------------------------------------------------------===//
4285 // CreateMaskOp
4286 //===----------------------------------------------------------------------===//
4287 
4288 LogicalResult CreateMaskOp::verify() {
4289 auto vectorType = getResult().getType().cast<VectorType>();
4290 // Verify that an operand was specified for each result vector dimension.
4291 if (vectorType.getRank() == 0) {
4292 if (getNumOperands() != 1)
4293 return emitOpError(
4294 "must specify exactly one operand for 0-D create_mask");
4295 } else if (getNumOperands() !=
4296 getResult().getType().cast<VectorType>().getRank()) {
4297 return emitOpError(
4298 "must specify an operand for each result vector dimension");
4299 }
4300 return success();
4301 }
4302 
4303 namespace {
4304 
4305 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
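// For example (illustrative):
//
// ```
// %c3 = arith.constant 3 : index
// %m = vector.create_mask %c3 : vector<8xi1>
// ```
// becomes
// ```
// %m = vector.constant_mask [3] : vector<8xi1>
// ```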
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto it : llvm::zip(createMaskOp.operands(),
                             createMaskOp.getType().getShape())) {
      auto *defOp = std::get<0>(it).getDefiningOp();
      int64_t maxDimSize = std::get<1>(it);
      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
      dimSize = std::min(dimSize, maxDimSize);
      // If one of dim sizes is zero, set all dims to zero.
      if (dimSize <= 0) {
        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
        break;
      }
      maskDimSizes.push_back(dimSize);
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};

} // namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = reduction_dim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
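  // For instance (shapes chosen for illustration only), a scan over
  // vector<4x8x16xf32> with reduction_dim = 1 expects an initial value of
  // type vector<4x16xf32>: the source shape with the reduction dimension
  // dropped.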
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/initial value shapes");
  }

  return success();
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = operands.front();
  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
    return {};

  // SplatElementsAttr::get treats single value for second arg as being a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"