1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements convenience types for working with super-vectorization 10 // operations, in particular super-vector loads and stores. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/Dialect/Utils/IndexingUtils.h" 21 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/BlockAndValueMapping.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/BuiltinOps.h" 27 #include "mlir/IR/BuiltinTypes.h" 28 #include "mlir/IR/DialectImplementation.h" 29 #include "mlir/IR/OpImplementation.h" 30 #include "mlir/IR/PatternMatch.h" 31 #include "mlir/IR/TypeUtilities.h" 32 #include "mlir/Support/LLVM.h" 33 #include "mlir/Support/MathExtras.h" 34 #include "llvm/ADT/StringSet.h" 35 #include "llvm/ADT/bit.h" 36 #include <numeric> 37 38 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" 39 // Pull in all enum type and utility function definitions. 40 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc" 41 42 using namespace mlir; 43 using namespace mlir::vector; 44 45 /// Helper enum to classify mask value. 46 enum class MaskFormat { 47 AllTrue = 0, 48 AllFalse = 1, 49 Unknown = 2, 50 }; 51 52 /// Helper method to classify a 1-D mask value. Currently, the method 53 /// looks "under the hood" of a constant value with dense attributes 54 /// and a constant mask operation (since the client may be called at 55 /// various stages during progressive lowering). 56 static MaskFormat get1DMaskFormat(Value mask) { 57 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) { 58 // Inspect constant dense values. We count up for bits that 59 // are set, count down for bits that are cleared, and bail 60 // when a mix is detected. 61 if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) { 62 int64_t val = 0; 63 for (bool b : denseElts.getValues<bool>()) 64 if (b && val >= 0) 65 val++; 66 else if (!b && val <= 0) 67 val--; 68 else 69 return MaskFormat::Unknown; 70 if (val > 0) 71 return MaskFormat::AllTrue; 72 if (val < 0) 73 return MaskFormat::AllFalse; 74 } 75 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) { 76 // Inspect constant mask index. If the index exceeds the 77 // dimension size, all bits are set. If the index is zero 78 // or less, no bits are set. 79 ArrayAttr masks = m.mask_dim_sizes(); 80 assert(masks.size() == 1); 81 int64_t i = masks[0].cast<IntegerAttr>().getInt(); 82 int64_t u = m.getType().getDimSize(0); 83 if (i >= u) 84 return MaskFormat::AllTrue; 85 if (i <= 0) 86 return MaskFormat::AllFalse; 87 } 88 return MaskFormat::Unknown; 89 } 90 91 // Helper for verifying combining kinds in contractions and reductions. 
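// In summary (mirroring the switch below): ADD and MUL accept integer, index,
// and float element types; the integer min/max and bitwise kinds
// (MINUI/MINSI/MAXUI/MAXSI/AND/OR/XOR) require an integer or index element
// type; MINF and MAXF require a float element type.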
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return elementType.isa<FloatType>();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      shapedType.getElementType().dyn_cast<VectorType>();
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
         defWrite.indices() == read.indices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.permutation_map() == read.permutation_map();
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.indices() == write.indices() &&
         priorWrite.mask() == write.mask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.permutation_map() == write.permutation_map();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
  // For simplicity, only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
    auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
    auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
    // If any of the indices are dynamic we cannot prove anything.
    if (!indexA || !indexB)
      continue;

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different, we know we are accessing disjoint slices.
173 if (indexA.getValue().cast<IntegerAttr>().getInt() != 174 indexB.getValue().cast<IntegerAttr>().getInt()) 175 return true; 176 } else { 177 // For this dimension, we slice a part of the memref we need to make sure 178 // the intervals accessed don't overlap. 179 int64_t distance = 180 std::abs(indexA.getValue().cast<IntegerAttr>().getInt() - 181 indexB.getValue().cast<IntegerAttr>().getInt()); 182 if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) 183 return true; 184 } 185 } 186 return false; 187 } 188 189 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA, 190 VectorTransferOpInterface transferB) { 191 if (transferA.source() != transferB.source()) 192 return false; 193 return isDisjointTransferIndices(transferA, transferB); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // CombiningKindAttr 198 //===----------------------------------------------------------------------===// 199 200 namespace mlir { 201 namespace vector { 202 namespace detail { 203 struct BitmaskEnumStorage : public AttributeStorage { 204 using KeyTy = uint64_t; 205 206 BitmaskEnumStorage(KeyTy val) : value(val) {} 207 208 bool operator==(const KeyTy &key) const { return value == key; } 209 210 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator, 211 const KeyTy &key) { 212 return new (allocator.allocate<BitmaskEnumStorage>()) 213 BitmaskEnumStorage(key); 214 } 215 216 KeyTy value = 0; 217 }; 218 } // namespace detail 219 } // namespace vector 220 } // namespace mlir 221 222 CombiningKindAttr CombiningKindAttr::get(CombiningKind kind, 223 MLIRContext *context) { 224 return Base::get(context, static_cast<uint64_t>(kind)); 225 } 226 227 CombiningKind CombiningKindAttr::getKind() const { 228 return static_cast<CombiningKind>(getImpl()->value); 229 } 230 231 static constexpr const CombiningKind combiningKindsList[] = { 232 // clang-format off 233 CombiningKind::ADD, 234 CombiningKind::MUL, 235 CombiningKind::MINUI, 236 CombiningKind::MINSI, 237 CombiningKind::MINF, 238 CombiningKind::MAXUI, 239 CombiningKind::MAXSI, 240 CombiningKind::MAXF, 241 CombiningKind::AND, 242 CombiningKind::OR, 243 CombiningKind::XOR, 244 // clang-format on 245 }; 246 247 void CombiningKindAttr::print(AsmPrinter &printer) const { 248 printer << "<"; 249 auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) { 250 return bitEnumContains(this->getKind(), kind); 251 }); 252 llvm::interleaveComma(kinds, printer, 253 [&](auto kind) { printer << stringifyEnum(kind); }); 254 printer << ">"; 255 } 256 257 Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) { 258 if (failed(parser.parseLess())) 259 return {}; 260 261 StringRef elemName; 262 if (failed(parser.parseKeyword(&elemName))) 263 return {}; 264 265 auto kind = symbolizeCombiningKind(elemName); 266 if (!kind) { 267 parser.emitError(parser.getNameLoc(), "Unknown combining kind: ") 268 << elemName; 269 return {}; 270 } 271 272 if (failed(parser.parseGreater())) 273 return {}; 274 275 return CombiningKindAttr::get(kind.getValue(), parser.getContext()); 276 } 277 278 Attribute VectorDialect::parseAttribute(DialectAsmParser &parser, 279 Type type) const { 280 StringRef attrKind; 281 if (parser.parseKeyword(&attrKind)) 282 return {}; 283 284 if (attrKind == "kind") 285 return CombiningKindAttr::parse(parser, {}); 286 287 parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind; 288 return {}; 289 } 290 291 void 
VectorDialect::printAttribute(Attribute attr, 292 DialectAsmPrinter &os) const { 293 if (auto ck = attr.dyn_cast<CombiningKindAttr>()) { 294 os << "kind"; 295 ck.print(os); 296 return; 297 } 298 llvm_unreachable("Unknown attribute type"); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // VectorDialect 303 //===----------------------------------------------------------------------===// 304 305 void VectorDialect::initialize() { 306 addAttributes<CombiningKindAttr>(); 307 308 addOperations< 309 #define GET_OP_LIST 310 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc" 311 >(); 312 } 313 314 /// Materialize a single constant operation from a given attribute value with 315 /// the desired resultant type. 316 Operation *VectorDialect::materializeConstant(OpBuilder &builder, 317 Attribute value, Type type, 318 Location loc) { 319 return builder.create<arith::ConstantOp>(loc, type, value); 320 } 321 322 IntegerType vector::getVectorSubscriptType(Builder &builder) { 323 return builder.getIntegerType(64); 324 } 325 326 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, 327 ArrayRef<int64_t> values) { 328 return builder.getI64ArrayAttr(values); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // MultiDimReductionOp 333 //===----------------------------------------------------------------------===// 334 335 void vector::MultiDimReductionOp::build(OpBuilder &builder, 336 OperationState &result, Value source, 337 ArrayRef<bool> reductionMask, 338 CombiningKind kind) { 339 result.addOperands(source); 340 auto sourceVectorType = source.getType().cast<VectorType>(); 341 auto targetType = MultiDimReductionOp::inferDestType( 342 sourceVectorType.getShape(), reductionMask, 343 sourceVectorType.getElementType()); 344 result.addTypes(targetType); 345 346 SmallVector<int64_t> reductionDims; 347 for (const auto &en : llvm::enumerate(reductionMask)) 348 if (en.value()) 349 reductionDims.push_back(en.index()); 350 result.addAttribute(getReductionDimsAttrStrName(), 351 builder.getI64ArrayAttr(reductionDims)); 352 result.addAttribute(getKindAttrStrName(), 353 CombiningKindAttr::get(kind, builder.getContext())); 354 } 355 356 LogicalResult MultiDimReductionOp::verify() { 357 auto reductionMask = getReductionMask(); 358 auto targetType = MultiDimReductionOp::inferDestType( 359 getSourceVectorType().getShape(), reductionMask, 360 getSourceVectorType().getElementType()); 361 // TODO: update to support 0-d vectors when available. 362 if (targetType != getDestType()) 363 return emitError("invalid output vector type: ") 364 << getDestType() << " (expected: " << targetType << ")"; 365 return success(); 366 } 367 368 OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) { 369 // Single parallel dim, this is a noop. 
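  // (For example, a 1-D multi_reduction whose only dimension is not listed in
  // reduction_dims simply forwards its source operand.)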
370 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0)) 371 return source(); 372 return {}; 373 } 374 375 //===----------------------------------------------------------------------===// 376 // ReductionOp 377 //===----------------------------------------------------------------------===// 378 379 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, 380 CombiningKind kind, Value vector) { 381 build(builder, result, kind, vector, /*acc=*/Value()); 382 } 383 384 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, 385 CombiningKind kind, Value vector, Value acc) { 386 build(builder, result, vector.getType().cast<VectorType>().getElementType(), 387 kind, vector, acc); 388 } 389 390 LogicalResult ReductionOp::verify() { 391 // Verify for 1-D vector. 392 int64_t rank = getVectorType().getRank(); 393 if (rank != 1) 394 return emitOpError("unsupported reduction rank: ") << rank; 395 396 // Verify supported reduction kind. 397 Type eltType = dest().getType(); 398 if (!isSupportedCombiningKind(kind(), eltType)) 399 return emitOpError("unsupported reduction type '") 400 << eltType << "' for kind '" << stringifyCombiningKind(kind()) 401 << "'"; 402 403 // Verify optional accumulator. 404 if (acc()) { 405 if (kind() != CombiningKind::ADD && kind() != CombiningKind::MUL) 406 return emitOpError("no accumulator for reduction kind: ") 407 << stringifyCombiningKind(kind()); 408 if (!eltType.isa<FloatType>()) 409 return emitOpError("no accumulator for type: ") << eltType; 410 } 411 412 return success(); 413 } 414 415 ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) { 416 SmallVector<OpAsmParser::OperandType, 2> operandsInfo; 417 Type redType; 418 Type resType; 419 CombiningKindAttr kindAttr; 420 if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind", 421 result.attributes) || 422 parser.parseComma() || parser.parseOperandList(operandsInfo) || 423 parser.parseColonType(redType) || 424 parser.parseKeywordType("into", resType) || 425 (!operandsInfo.empty() && 426 parser.resolveOperand(operandsInfo[0], redType, result.operands)) || 427 (operandsInfo.size() > 1 && 428 parser.resolveOperand(operandsInfo[1], resType, result.operands)) || 429 parser.addTypeToList(resType, result.types)) 430 return failure(); 431 if (operandsInfo.empty() || operandsInfo.size() > 2) 432 return parser.emitError(parser.getNameLoc(), 433 "unsupported number of operands"); 434 return success(); 435 } 436 437 void ReductionOp::print(OpAsmPrinter &p) { 438 p << " "; 439 kindAttr().print(p); 440 p << ", " << vector(); 441 if (acc()) 442 p << ", " << acc(); 443 p << " : " << vector().getType() << " into " << dest().getType(); 444 } 445 446 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, 447 OpBuilder &builder, Location loc, 448 Value vector) { 449 switch (op) { 450 case arith::AtomicRMWKind::addf: 451 case arith::AtomicRMWKind::addi: 452 return builder.create<vector::ReductionOp>(vector.getLoc(), 453 CombiningKind::ADD, vector); 454 case arith::AtomicRMWKind::mulf: 455 case arith::AtomicRMWKind::muli: 456 return builder.create<vector::ReductionOp>(vector.getLoc(), 457 CombiningKind::MUL, vector); 458 case arith::AtomicRMWKind::minf: 459 return builder.create<vector::ReductionOp>(vector.getLoc(), 460 CombiningKind::MINF, vector); 461 case arith::AtomicRMWKind::mins: 462 return builder.create<vector::ReductionOp>(vector.getLoc(), 463 CombiningKind::MINSI, vector); 464 case arith::AtomicRMWKind::minu: 465 return 
builder.create<vector::ReductionOp>(vector.getLoc(), 466 CombiningKind::MINUI, vector); 467 case arith::AtomicRMWKind::maxf: 468 return builder.create<vector::ReductionOp>(vector.getLoc(), 469 CombiningKind::MAXF, vector); 470 case arith::AtomicRMWKind::maxs: 471 return builder.create<vector::ReductionOp>(vector.getLoc(), 472 CombiningKind::MAXSI, vector); 473 case arith::AtomicRMWKind::maxu: 474 return builder.create<vector::ReductionOp>(vector.getLoc(), 475 CombiningKind::MAXUI, vector); 476 // TODO: Add remaining reduction operations. 477 default: 478 (void)emitOptionalError(loc, "Reduction operation type not supported"); 479 break; 480 } 481 return nullptr; 482 } 483 484 //===----------------------------------------------------------------------===// 485 // ContractionOp 486 //===----------------------------------------------------------------------===// 487 488 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, 489 Value lhs, Value rhs, Value acc, 490 ArrayRef<ArrayRef<AffineExpr>> indexingExprs, 491 ArrayRef<StringRef> iteratorTypes) { 492 result.addOperands({lhs, rhs, acc}); 493 result.addTypes(acc.getType()); 494 result.addAttribute(::mlir::getIndexingMapsAttrName(), 495 builder.getAffineMapArrayAttr( 496 AffineMap::inferFromExprList(indexingExprs))); 497 result.addAttribute(::mlir::getIteratorTypesAttrName(), 498 builder.getStrArrayAttr(iteratorTypes)); 499 } 500 501 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, 502 Value lhs, Value rhs, Value acc, 503 ArrayAttr indexingMaps, 504 ArrayAttr iteratorTypes) { 505 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes, 506 ContractionOp::getDefaultKind()); 507 } 508 509 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, 510 Value lhs, Value rhs, Value acc, 511 ArrayAttr indexingMaps, 512 ArrayAttr iteratorTypes, CombiningKind kind) { 513 result.addOperands({lhs, rhs, acc}); 514 result.addTypes(acc.getType()); 515 result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps); 516 result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes); 517 result.addAttribute(ContractionOp::getKindAttrStrName(), 518 CombiningKindAttr::get(kind, builder.getContext())); 519 } 520 521 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { 522 OpAsmParser::OperandType lhsInfo; 523 OpAsmParser::OperandType rhsInfo; 524 OpAsmParser::OperandType accInfo; 525 SmallVector<OpAsmParser::OperandType, 2> masksInfo; 526 SmallVector<Type, 2> types; 527 Type resultType; 528 auto loc = parser.getCurrentLocation(); 529 DictionaryAttr dictAttr; 530 // TODO: Unify linalg op attribute parsing. 
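  // The custom form parsed here looks roughly like (illustrative example):
  //   vector.contract {indexing_maps = [...], iterator_types = [...]}
  //     %lhs, %rhs, %acc
  //     : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>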
531 if (parser.parseAttribute(dictAttr, "_", result.attributes) || 532 parser.parseOperand(lhsInfo) || parser.parseComma() || 533 parser.parseOperand(rhsInfo) || parser.parseComma() || 534 parser.parseOperand(accInfo) || 535 parser.parseTrailingOperandList(masksInfo) || 536 parser.parseOptionalAttrDict(result.attributes) || 537 parser.parseColonTypeList(types) || 538 parser.parseKeywordType("into", resultType) || 539 parser.resolveOperand(lhsInfo, types[0], result.operands) || 540 parser.resolveOperand(rhsInfo, types[1], result.operands) || 541 parser.resolveOperand(accInfo, resultType, result.operands) || 542 parser.addTypeToList(resultType, result.types)) 543 return failure(); 544 result.attributes.assign(dictAttr.getValue().begin(), 545 dictAttr.getValue().end()); 546 if (!result.attributes.get(ContractionOp::getKindAttrStrName())) { 547 result.addAttribute(ContractionOp::getKindAttrStrName(), 548 CombiningKindAttr::get(ContractionOp::getDefaultKind(), 549 result.getContext())); 550 } 551 if (masksInfo.empty()) 552 return success(); 553 if (masksInfo.size() != 2) 554 return parser.emitError(parser.getNameLoc(), 555 "expected zero or exactly 2 vector mask operands"); 556 auto lhsType = types[0].cast<VectorType>(); 557 auto rhsType = types[1].cast<VectorType>(); 558 auto maskElementType = parser.getBuilder().getI1Type(); 559 std::array<Type, 2> maskTypes = { 560 VectorType::Builder(lhsType).setElementType(maskElementType), 561 VectorType::Builder(rhsType).setElementType(maskElementType)}; 562 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands)) 563 return failure(); 564 return success(); 565 } 566 567 void ContractionOp::print(OpAsmPrinter &p) { 568 // TODO: Unify printing code with linalg ops. 569 auto attrNames = getTraitAttrNames(); 570 llvm::StringSet<> traitAttrsSet; 571 traitAttrsSet.insert(attrNames.begin(), attrNames.end()); 572 SmallVector<NamedAttribute, 8> attrs; 573 for (auto attr : (*this)->getAttrs()) 574 if (traitAttrsSet.count(attr.getName().strref()) > 0) 575 attrs.push_back(attr); 576 577 auto dictAttr = DictionaryAttr::get(getContext(), attrs); 578 p << " " << dictAttr << " " << lhs() << ", "; 579 p << rhs() << ", " << acc(); 580 if (masks().size() == 2) 581 p << ", " << masks(); 582 583 p.printOptionalAttrDict((*this)->getAttrs(), attrNames); 584 p << " : " << lhs().getType() << ", " << rhs().getType() << " into " 585 << getResultType(); 586 } 587 588 static bool verifyDimMap(VectorType lhsType, VectorType rhsType, 589 const std::vector<std::pair<int64_t, int64_t>> &map) { 590 for (auto &dimPair : map) { 591 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() || 592 dimPair.second < 0 || dimPair.second >= rhsType.getRank() || 593 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second)) 594 return false; 595 } 596 return true; 597 } 598 599 static LogicalResult verifyOutputShape( 600 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, 601 Type resType, 602 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap, 603 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) { 604 DenseSet<int64_t> lhsContractingDimSet; 605 DenseSet<int64_t> rhsContractingDimSet; 606 for (auto &dimPair : contractingDimMap) { 607 lhsContractingDimSet.insert(dimPair.first); 608 rhsContractingDimSet.insert(dimPair.second); 609 } 610 DenseSet<int64_t> rhsBatchDimSet; 611 for (auto &dimPair : batchDimMap) 612 rhsBatchDimSet.insert(dimPair.second); 613 614 // Add free and batch dimensions from 'lhsType' to 
'expectedResultDims'. 615 SmallVector<int64_t, 4> expectedResultDims; 616 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) { 617 if (lhsContractingDimSet.count(i) > 0) 618 continue; 619 expectedResultDims.push_back(lhsType.getDimSize(i)); 620 } 621 622 // Add free dimensions from 'rhsType' to 'expectedResultDims'. 623 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) { 624 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0) 625 continue; 626 expectedResultDims.push_back(rhsType.getDimSize(i)); 627 } 628 629 // Verify 'expectedResultDims'. 630 if (expectedResultDims.empty()) { 631 // No batch or free dimension implies a scalar result. 632 if (resType.isa<VectorType>() || accType.isa<VectorType>()) 633 return op.emitOpError("invalid accumulator/result vector shape"); 634 } else { 635 // At least one batch or free dimension implies a vector result. 636 auto resVectorType = resType.dyn_cast<VectorType>(); 637 auto accVectorType = accType.dyn_cast<VectorType>(); 638 if (!resVectorType || !accVectorType) 639 return op.emitOpError("invalid accumulator/result vector shape"); 640 641 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector 642 // types fully define the result vector type. This assumes the affine maps 643 // are well-formed, which must have been verified already. 644 MLIRContext *ctx = op.getContext(); 645 AffineMap lhsMap = op.getIndexingMaps()[0]; 646 AffineMap rhsMap = op.getIndexingMaps()[1]; 647 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs()); 648 for (auto pair : 649 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) { 650 VectorType v = pair.first; 651 auto map = pair.second; 652 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) { 653 unsigned pos = map.getDimPosition(idx); 654 if (!extents[pos]) 655 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx); 656 } 657 } 658 assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) && 659 "expected extent along all dimensions."); 660 661 AffineMap resMap = op.getIndexingMaps()[2]; 662 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(), 663 /*symCount=*/0, extents, ctx); 664 // Compose the resMap with the extentsMap, which is a constant map. 665 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap)); 666 assert(llvm::all_of( 667 expectedMap.getResults(), 668 [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) && 669 "expected constant extent along all dimensions."); 670 // Extract the expected shape and build the type. 671 auto expectedShape = llvm::to_vector<4>( 672 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) { 673 return e.cast<AffineConstantExpr>().getValue(); 674 })); 675 auto expected = 676 VectorType::get(expectedShape, resVectorType.getElementType()); 677 if (resVectorType != expected || accVectorType != expected) 678 return op.emitOpError( 679 "invalid accumulator/result vector shape, expected: ") 680 << expected; 681 } 682 return success(); 683 } 684 685 LogicalResult ContractionOp::verify() { 686 auto lhsType = getLhsType(); 687 auto rhsType = getRhsType(); 688 auto accType = getAccType(); 689 auto resType = getResultType(); 690 691 // Verify that an indexing map was specified for each vector operand. 692 if (indexing_maps().size() != 3) 693 return emitOpError("expected an indexing map for each vector operand"); 694 695 // Verify that each index map has 'numIterators' inputs, no symbols, and 696 // that the number of map outputs equals the rank of its associated 697 // vector operand. 
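  // For instance, for a plain matmul-like contraction the three maps would be
  //   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n),
  // each a projected permutation over the three iterators.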
698 unsigned numIterators = iterator_types().getValue().size(); 699 for (const auto &it : llvm::enumerate(indexing_maps())) { 700 auto index = it.index(); 701 auto map = it.value(); 702 if (map.getNumSymbols() != 0) 703 return emitOpError("expected indexing map ") 704 << index << " to have no symbols"; 705 auto vectorType = getOperand(index).getType().dyn_cast<VectorType>(); 706 unsigned rank = vectorType ? vectorType.getShape().size() : 0; 707 // Verify that the map has the right number of inputs, outputs, and indices. 708 // This also correctly accounts for (..) -> () for rank-0 results. 709 if (map.getNumDims() != numIterators) 710 return emitOpError("expected indexing map ") 711 << index << " to have " << numIterators << " number of inputs"; 712 if (map.getNumResults() != rank) 713 return emitOpError("expected indexing map ") 714 << index << " to have " << rank << " number of outputs"; 715 if (!map.isProjectedPermutation()) 716 return emitOpError("expected indexing map ") 717 << index << " to be a projected permutation of its inputs"; 718 } 719 720 auto contractingDimMap = getContractingDimMap(); 721 auto batchDimMap = getBatchDimMap(); 722 723 // Verify at least one contracting dimension pair was specified. 724 if (contractingDimMap.empty()) 725 return emitOpError("expected at least one contracting dimension pair"); 726 727 // Verify contracting dimension map was properly constructed. 728 if (!verifyDimMap(lhsType, rhsType, contractingDimMap)) 729 return emitOpError("invalid contracting dimension map"); 730 731 // Verify batch dimension map was properly constructed. 732 if (!verifyDimMap(lhsType, rhsType, batchDimMap)) 733 return emitOpError("invalid batch dimension map"); 734 735 // Verify 'accType' and 'resType' shape. 736 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType, 737 contractingDimMap, batchDimMap))) 738 return failure(); 739 740 // Verify that either two vector masks are set or none are set. 741 auto lhsMaskType = getLHSVectorMaskType(); 742 auto rhsMaskType = getRHSVectorMaskType(); 743 if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType)) 744 return emitOpError("invalid number of vector masks specified"); 745 if (lhsMaskType && rhsMaskType) { 746 // Verify mask rank == argument rank. 747 if (lhsMaskType.getShape().size() != lhsType.getShape().size() || 748 rhsMaskType.getShape().size() != rhsType.getShape().size()) 749 return emitOpError("invalid vector mask rank"); 750 } 751 752 // Verify supported combining kind. 753 auto vectorType = resType.dyn_cast<VectorType>(); 754 auto elementType = vectorType ? 
vectorType.getElementType() : resType; 755 if (!isSupportedCombiningKind(kind(), elementType)) 756 return emitOpError("unsupported contraction type"); 757 758 return success(); 759 } 760 761 ArrayRef<StringRef> ContractionOp::getTraitAttrNames() { 762 static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(), 763 ::mlir::getIteratorTypesAttrName(), 764 ContractionOp::getKindAttrStrName()}; 765 return llvm::makeArrayRef(names); 766 } 767 768 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) { 769 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) 770 if (targetExpr == map.getResult(i)) 771 return i; 772 return -1; 773 } 774 775 static std::vector<std::pair<int64_t, int64_t>> 776 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes, 777 StringRef targetIteratorTypeName, MLIRContext *context) { 778 std::vector<std::pair<int64_t, int64_t>> dimMap; 779 for (const auto &it : llvm::enumerate(iteratorTypes)) { 780 auto iteratorTypeName = it.value().cast<StringAttr>().getValue(); 781 if (iteratorTypeName != targetIteratorTypeName) 782 continue; 783 // Search lhs/rhs map results for 'targetExpr'. 784 auto targetExpr = getAffineDimExpr(it.index(), context); 785 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr); 786 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr); 787 if (lhsDim >= 0 && rhsDim >= 0) 788 dimMap.emplace_back(lhsDim, rhsDim); 789 } 790 return dimMap; 791 } 792 793 void ContractionOp::getIterationBounds( 794 SmallVectorImpl<int64_t> &iterationBounds) { 795 auto lhsShape = getLhsType().getShape(); 796 auto resVectorType = getResultType().dyn_cast<VectorType>(); 797 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps()); 798 SmallVector<int64_t, 2> iterationShape; 799 for (const auto &it : llvm::enumerate(iterator_types())) { 800 // Search lhs/rhs map results for 'targetExpr'. 801 auto targetExpr = getAffineDimExpr(it.index(), getContext()); 802 auto iteratorTypeName = it.value().cast<StringAttr>().getValue(); 803 if (iteratorTypeName == getReductionIteratorTypeName()) { 804 // Get reduction dim size from lhs shape (same size in rhsShape). 805 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr); 806 assert(lhsDimIndex >= 0); 807 iterationBounds.push_back(lhsShape[lhsDimIndex]); 808 continue; 809 } 810 // Get parallel dimension size from result shape. 
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which rewrites a pattern such as:
///
/// ```mlir
/// %c0 = arith.constant 0: ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// into:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.acc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.acc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.acc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ?
success() : failure(); 900 } 901 }; 902 903 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, 904 MLIRContext *context) { 905 results.add<CanonicalizeContractAdd<arith::AddIOp>, 906 CanonicalizeContractAdd<arith::AddFOp>>(context); 907 } 908 909 //===----------------------------------------------------------------------===// 910 // ExtractElementOp 911 //===----------------------------------------------------------------------===// 912 913 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, 914 Value source) { 915 result.addOperands({source}); 916 result.addTypes(source.getType().cast<VectorType>().getElementType()); 917 } 918 919 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, 920 Value source, Value position) { 921 result.addOperands({source, position}); 922 result.addTypes(source.getType().cast<VectorType>().getElementType()); 923 } 924 925 LogicalResult vector::ExtractElementOp::verify() { 926 VectorType vectorType = getVectorType(); 927 if (vectorType.getRank() == 0) { 928 if (position()) 929 return emitOpError("expected position to be empty with 0-D vector"); 930 return success(); 931 } 932 if (vectorType.getRank() != 1) 933 return emitOpError("unexpected >1 vector rank"); 934 if (!position()) 935 return emitOpError("expected position for 1-D vector"); 936 return success(); 937 } 938 939 //===----------------------------------------------------------------------===// 940 // ExtractOp 941 //===----------------------------------------------------------------------===// 942 943 static Type inferExtractOpResultType(VectorType vectorType, 944 ArrayAttr position) { 945 if (static_cast<int64_t>(position.size()) == vectorType.getRank()) 946 return vectorType.getElementType(); 947 return VectorType::get(vectorType.getShape().drop_front(position.size()), 948 vectorType.getElementType()); 949 } 950 951 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, 952 Value source, ArrayRef<int64_t> position) { 953 result.addOperands(source); 954 auto positionAttr = getVectorSubscriptAttr(builder, position); 955 result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(), 956 positionAttr)); 957 result.addAttribute(getPositionAttrStrName(), positionAttr); 958 } 959 960 // Convenience builder which assumes the values are constant indices. 
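// Each position value must be produced by an arith.constant op with index
// type; the constant values are extracted and forwarded to the attribute-based
// builder above.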
961 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, 962 Value source, ValueRange position) { 963 SmallVector<int64_t, 4> positionConstants = 964 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { 965 return pos.getDefiningOp<arith::ConstantIndexOp>().value(); 966 })); 967 build(builder, result, source, positionConstants); 968 } 969 970 void vector::ExtractOp::print(OpAsmPrinter &p) { 971 p << " " << vector() << position(); 972 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 973 p << " : " << vector().getType(); 974 } 975 976 ParseResult vector::ExtractOp::parse(OpAsmParser &parser, 977 OperationState &result) { 978 SMLoc attributeLoc, typeLoc; 979 NamedAttrList attrs; 980 OpAsmParser::OperandType vector; 981 Type type; 982 Attribute attr; 983 if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) || 984 parser.parseAttribute(attr, "position", attrs) || 985 parser.parseOptionalAttrDict(attrs) || 986 parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type)) 987 return failure(); 988 989 auto vectorType = type.dyn_cast<VectorType>(); 990 if (!vectorType) 991 return parser.emitError(typeLoc, "expected vector type"); 992 993 auto positionAttr = attr.dyn_cast<ArrayAttr>(); 994 if (!positionAttr || 995 static_cast<int64_t>(positionAttr.size()) > vectorType.getRank()) 996 return parser.emitError( 997 attributeLoc, 998 "expected position attribute of rank smaller than vector rank"); 999 1000 Type resType = inferExtractOpResultType(vectorType, positionAttr); 1001 result.attributes = attrs; 1002 return failure(parser.resolveOperand(vector, type, result.operands) || 1003 parser.addTypeToList(resType, result.types)); 1004 } 1005 1006 LogicalResult vector::ExtractOp::verify() { 1007 auto positionAttr = position().getValue(); 1008 if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank())) 1009 return emitOpError( 1010 "expected position attribute of rank smaller than vector rank"); 1011 for (const auto &en : llvm::enumerate(positionAttr)) { 1012 auto attr = en.value().dyn_cast<IntegerAttr>(); 1013 if (!attr || attr.getInt() < 0 || 1014 attr.getInt() >= getVectorType().getDimSize(en.index())) 1015 return emitOpError("expected position attribute #") 1016 << (en.index() + 1) 1017 << " to be a non-negative integer smaller than the corresponding " 1018 "vector dimension"; 1019 } 1020 return success(); 1021 } 1022 1023 template <typename IntType> 1024 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) { 1025 return llvm::to_vector<4>(llvm::map_range( 1026 arrayAttr.getAsRange<IntegerAttr>(), 1027 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); 1028 } 1029 1030 /// Fold the result of chains of ExtractOp in place by simply concatenating the 1031 /// positions. 1032 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { 1033 if (!extractOp.vector().getDefiningOp<ExtractOp>()) 1034 return failure(); 1035 1036 SmallVector<int64_t, 4> globalPosition; 1037 ExtractOp currentOp = extractOp; 1038 auto extrPos = extractVector<int64_t>(currentOp.position()); 1039 globalPosition.append(extrPos.rbegin(), extrPos.rend()); 1040 while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) { 1041 currentOp = nextOp; 1042 auto extrPos = extractVector<int64_t>(currentOp.position()); 1043 globalPosition.append(extrPos.rbegin(), extrPos.rend()); 1044 } 1045 extractOp.setOperand(currentOp.vector()); 1046 // OpBuilder is only used as a helper to build an I64ArrayAttr. 
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result
/// vector. As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as
  /// equality of all entries of `a` that are >= 0 with the corresponding
  /// entries of `b`. Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto it : llvm::zip(a, b)) {
      if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
        continue;
      if (std::get<0>(it) != std::get<1>(it))
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
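  /// Fails (without changing the internal state) when the inserted position is
  /// not a prefix of `extractPosition`.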
1127 LogicalResult handleInsertOpWithPrefixPos(Value &res); 1128 1129 /// Try to fold in place to extract(source, extractPosition) and return the 1130 /// folded result. Return null if folding is not possible (e.g. due to an 1131 /// internal tranposition in the result). 1132 Value tryToFoldExtractOpInPlace(Value source); 1133 1134 ExtractOp extractOp; 1135 int64_t vectorRank; 1136 int64_t extractedRank; 1137 1138 InsertOp nextInsertOp; 1139 TransposeOp nextTransposeOp; 1140 1141 /// Sentinel values that encode the internal permutation status of the result. 1142 /// They are set to (-1, ... , -k) at the beginning and appended to 1143 /// `extractPosition`. 1144 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to 1145 /// ensure that there is no internal transposition. 1146 /// Internal transposition cannot be accounted for with a folding pattern. 1147 // TODO: We could relax the internal transposition with an extra transposition 1148 // operation in a future canonicalizer. 1149 SmallVector<int64_t> sentinels; 1150 SmallVector<int64_t> extractPosition; 1151 }; 1152 } // namespace 1153 1154 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( 1155 ExtractOp e) 1156 : extractOp(e), vectorRank(extractOp.getVectorType().getRank()), 1157 extractedRank(extractOp.position().size()) { 1158 assert(vectorRank >= extractedRank && "extracted pos overflow"); 1159 sentinels.reserve(vectorRank - extractedRank); 1160 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) 1161 sentinels.push_back(-(i + 1)); 1162 extractPosition = extractVector<int64_t>(extractOp.position()); 1163 llvm::append_range(extractPosition, sentinels); 1164 } 1165 1166 // Case 1. If we hit a transpose, just compose the map and iterate. 1167 // Invariant: insert + transpose do not change rank, we can always compose. 1168 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { 1169 if (!nextTransposeOp) 1170 return failure(); 1171 auto permutation = extractVector<unsigned>(nextTransposeOp.transp()); 1172 AffineMap m = inversePermutation( 1173 AffineMap::getPermutationMap(permutation, extractOp.getContext())); 1174 extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition)); 1175 return success(); 1176 } 1177 1178 // Case 2: the insert position matches extractPosition exactly, early return. 1179 LogicalResult 1180 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( 1181 Value &res) { 1182 auto insertedPos = extractVector<int64_t>(nextInsertOp.position()); 1183 if (makeArrayRef(insertedPos) != 1184 llvm::makeArrayRef(extractPosition).take_front(extractedRank)) 1185 return failure(); 1186 // Case 2.a. early-exit fold. 1187 res = nextInsertOp.source(); 1188 // Case 2.b. if internal transposition is present, canFold will be false. 1189 return success(); 1190 } 1191 1192 /// Case 3: if inserted position is a prefix of extractPosition, 1193 /// extract a portion of the source of the insertion. 1194 /// This method updates the internal state. 1195 LogicalResult 1196 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { 1197 auto insertedPos = extractVector<int64_t>(nextInsertOp.position()); 1198 if (!isContainedWithin(insertedPos, extractPosition)) 1199 return failure(); 1200 // Set leading dims to zero. 1201 std::fill_n(extractPosition.begin(), insertedPos.size(), 0); 1202 // Drop extra leading dims. 
1203 extractPosition.erase(extractPosition.begin(), 1204 extractPosition.begin() + insertedPos.size()); 1205 extractedRank = extractPosition.size() - sentinels.size(); 1206 // Case 3.a. early-exit fold (break and delegate to post-while path). 1207 res = nextInsertOp.source(); 1208 // Case 3.b. if internal transposition is present, canFold will be false. 1209 return success(); 1210 } 1211 1212 /// Try to fold in place to extract(source, extractPosition) and return the 1213 /// folded result. Return null if folding is not possible (e.g. due to an 1214 /// internal tranposition in the result). 1215 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace( 1216 Value source) { 1217 // If we can't fold (either internal transposition, or nothing to fold), bail. 1218 bool nothingToFold = (source == extractOp.vector()); 1219 if (nothingToFold || !canFold()) 1220 return Value(); 1221 // Otherwise, fold by updating the op inplace and return its result. 1222 OpBuilder b(extractOp.getContext()); 1223 extractOp->setAttr( 1224 extractOp.positionAttrName(), 1225 b.getI64ArrayAttr( 1226 makeArrayRef(extractPosition).take_front(extractedRank))); 1227 extractOp.vectorMutable().assign(source); 1228 return extractOp.getResult(); 1229 } 1230 1231 /// Iterate over producing insert and transpose ops until we find a fold. 1232 Value ExtractFromInsertTransposeChainState::fold() { 1233 Value valueToExtractFrom = extractOp.vector(); 1234 updateStateForNextIteration(valueToExtractFrom); 1235 while (nextInsertOp || nextTransposeOp) { 1236 // Case 1. If we hit a transpose, just compose the map and iterate. 1237 // Invariant: insert + transpose do not change rank, we can always compose. 1238 if (succeeded(handleTransposeOp())) { 1239 valueToExtractFrom = nextTransposeOp.vector(); 1240 updateStateForNextIteration(valueToExtractFrom); 1241 continue; 1242 } 1243 1244 Value result; 1245 // Case 2: the position match exactly. 1246 if (succeeded(handleInsertOpWithMatchingPos(result))) 1247 return result; 1248 1249 // Case 3: if the inserted position is a prefix of extractPosition, we can 1250 // just extract a portion of the source of the insert. 1251 if (succeeded(handleInsertOpWithPrefixPos(result))) 1252 return tryToFoldExtractOpInPlace(result); 1253 1254 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel 1255 // values. This is a more difficult case and we bail. 1256 auto insertedPos = extractVector<int64_t>(nextInsertOp.position()); 1257 if (isContainedWithin(extractPosition, insertedPos) || 1258 intersectsWhereNonNegative(extractPosition, insertedPos)) 1259 return Value(); 1260 1261 // Case 5: No intersection, we forward the extract to insertOp.dest(). 1262 valueToExtractFrom = nextInsertOp.dest(); 1263 updateStateForNextIteration(valueToExtractFrom); 1264 } 1265 // If after all this we can fold, go for it. 1266 return tryToFoldExtractOpInPlace(valueToExtractFrom); 1267 } 1268 1269 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. 1270 static Value foldExtractFromBroadcast(ExtractOp extractOp) { 1271 Operation *defOp = extractOp.vector().getDefiningOp(); 1272 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) 1273 return Value(); 1274 Value source = defOp->getOperand(0); 1275 if (extractOp.getType() == source.getType()) 1276 return source; 1277 auto getRank = [](Type type) { 1278 return type.isa<VectorType>() ? 
                                    type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(source.getType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(source);
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                       b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  return Value();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated with the shapeCast op vector source
  // and delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}

/// Fold an ExtractOp from ExtractStridedSliceOp.
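/// An illustrative example (assuming unit strides):
///   %0 = vector.extract_strided_slice %v
///          {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]}
///        : vector<4x8xf32> to vector<2x8xf32>
///   %1 = vector.extract %0[0] : vector<2x8xf32>
/// folds in place to:
///   %1 = vector.extract %v[1] : vector<4x8xf32>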
1358 static Value foldExtractFromExtractStrided(ExtractOp extractOp) { 1359 auto extractStridedSliceOp = 1360 extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>(); 1361 if (!extractStridedSliceOp) 1362 return Value(); 1363 // Return if 'extractStridedSliceOp' has non-unit strides. 1364 if (extractStridedSliceOp.hasNonUnitStrides()) 1365 return Value(); 1366 1367 // Trim offsets for dimensions fully extracted. 1368 auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets()); 1369 while (!sliceOffsets.empty()) { 1370 size_t lastOffset = sliceOffsets.size() - 1; 1371 if (sliceOffsets.back() != 0 || 1372 extractStridedSliceOp.getType().getDimSize(lastOffset) != 1373 extractStridedSliceOp.getVectorType().getDimSize(lastOffset)) 1374 break; 1375 sliceOffsets.pop_back(); 1376 } 1377 unsigned destinationRank = 0; 1378 if (auto vecType = extractOp.getType().dyn_cast<VectorType>()) 1379 destinationRank = vecType.getRank(); 1380 // The dimensions of the result need to be untouched by the 1381 // extractStridedSlice op. 1382 if (destinationRank > 1383 extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size()) 1384 return Value(); 1385 auto extractedPos = extractVector<int64_t>(extractOp.position()); 1386 assert(extractedPos.size() >= sliceOffsets.size()); 1387 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) 1388 extractedPos[i] = extractedPos[i] + sliceOffsets[i]; 1389 extractOp.vectorMutable().assign(extractStridedSliceOp.vector()); 1390 // OpBuilder is only used as a helper to build an I64ArrayAttr. 1391 OpBuilder b(extractOp.getContext()); 1392 extractOp->setAttr(ExtractOp::getPositionAttrStrName(), 1393 b.getI64ArrayAttr(extractedPos)); 1394 return extractOp.getResult(); 1395 } 1396 1397 /// Fold extract_op fed from a chain of insertStridedSlice ops. 1398 static Value foldExtractStridedOpFromInsertChain(ExtractOp op) { 1399 int64_t destinationRank = op.getType().isa<VectorType>() 1400 ? op.getType().cast<VectorType>().getRank() 1401 : 0; 1402 auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>(); 1403 while (insertOp) { 1404 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() - 1405 insertOp.getSourceVectorType().getRank(); 1406 if (destinationRank > insertOp.getSourceVectorType().getRank()) 1407 return Value(); 1408 auto insertOffsets = extractVector<int64_t>(insertOp.offsets()); 1409 auto extractOffsets = extractVector<int64_t>(op.position()); 1410 1411 if (llvm::any_of(insertOp.strides(), [](Attribute attr) { 1412 return attr.cast<IntegerAttr>().getInt() != 1; 1413 })) 1414 return Value(); 1415 bool disjoint = false; 1416 SmallVector<int64_t, 4> offsetDiffs; 1417 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) { 1418 int64_t start = insertOffsets[dim]; 1419 int64_t size = 1420 (dim < insertRankDiff) 1421 ? 1 1422 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff); 1423 int64_t end = start + size; 1424 int64_t offset = extractOffsets[dim]; 1425 // Check if the start of the extract offset is in the interval inserted. 1426 if (start <= offset && offset < end) { 1427 if (dim >= insertRankDiff) 1428 offsetDiffs.push_back(offset - start); 1429 continue; 1430 } 1431 disjoint = true; 1432 break; 1433 } 1434 // The extract element chunk overlap with the vector inserted. 1435 if (!disjoint) { 1436 // If any of the inner dimensions are only partially inserted we have a 1437 // partial overlap. 
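      // In that case the extracted chunk would mix inserted and original
      // values, so neither operand can be forwarded and we bail out below.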
1438 int64_t srcRankDiff = 1439 insertOp.getSourceVectorType().getRank() - destinationRank; 1440 for (int64_t i = 0; i < destinationRank; i++) { 1441 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) != 1442 insertOp.getDestVectorType().getDimSize(i + srcRankDiff + 1443 insertRankDiff)) 1444 return Value(); 1445 } 1446 op.vectorMutable().assign(insertOp.source()); 1447 // OpBuilder is only used as a helper to build an I64ArrayAttr. 1448 OpBuilder b(op.getContext()); 1449 op->setAttr(ExtractOp::getPositionAttrStrName(), 1450 b.getI64ArrayAttr(offsetDiffs)); 1451 return op.getResult(); 1452 } 1453 // If the chunk extracted is disjoint from the chunk inserted, keep 1454 // looking in the insert chain. 1455 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>(); 1456 } 1457 return Value(); 1458 } 1459 1460 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) { 1461 if (position().empty()) 1462 return vector(); 1463 if (succeeded(foldExtractOpFromExtractChain(*this))) 1464 return getResult(); 1465 if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) 1466 return res; 1467 if (auto res = foldExtractFromBroadcast(*this)) 1468 return res; 1469 if (auto res = foldExtractFromShapeCast(*this)) 1470 return res; 1471 if (auto val = foldExtractFromExtractStrided(*this)) 1472 return val; 1473 if (auto val = foldExtractStridedOpFromInsertChain(*this)) 1474 return val; 1475 return OpFoldResult(); 1476 } 1477 1478 namespace { 1479 1480 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast. 1481 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { 1482 public: 1483 using OpRewritePattern<ExtractOp>::OpRewritePattern; 1484 1485 LogicalResult matchAndRewrite(ExtractOp extractOp, 1486 PatternRewriter &rewriter) const override { 1487 Operation *defOp = extractOp.vector().getDefiningOp(); 1488 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) 1489 return failure(); 1490 Value source = defOp->getOperand(0); 1491 if (extractOp.getType() == source.getType()) 1492 return failure(); 1493 auto getRank = [](Type type) { 1494 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0; 1495 }; 1496 unsigned broadcastSrcRank = getRank(source.getType()); 1497 unsigned extractResultRank = getRank(extractOp.getType()); 1498 // We only consider the case where the rank of the source is smaller than 1499 // the rank of the extract dst. The other cases are handled in the folding 1500 // patterns. 1501 if (extractResultRank <= broadcastSrcRank) 1502 return failure(); 1503 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1504 extractOp, extractOp.getType(), source); 1505 return success(); 1506 } 1507 }; 1508 1509 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp. 1510 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> { 1511 public: 1512 using OpRewritePattern<ExtractOp>::OpRewritePattern; 1513 1514 LogicalResult matchAndRewrite(ExtractOp extractOp, 1515 PatternRewriter &rewriter) const override { 1516 // Return if the 'extractOp' operand is not defined by a 1517 // ConstantOp.
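// For illustration (hypothetical types): extracting from a splat constant such as
//   %c = arith.constant dense<1.0> : vector<4x8xf32>
//   %e = vector.extract %c[2] : vector<4x8xf32>
// is rewritten to %e = arith.constant dense<1.0> : vector<8xf32>.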
1518 auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>(); 1519 if (!constantOp) 1520 return failure(); 1521 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 1522 if (!dense) 1523 return failure(); 1524 Attribute newAttr = dense.getSplatValue<Attribute>(); 1525 if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>()) 1526 newAttr = DenseElementsAttr::get(vecDstType, newAttr); 1527 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr); 1528 return success(); 1529 } 1530 }; 1531 1532 } // namespace 1533 1534 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, 1535 MLIRContext *context) { 1536 results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context); 1537 } 1538 1539 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 1540 SmallVectorImpl<int64_t> &results) { 1541 for (auto attr : arrayAttr) 1542 results.push_back(attr.cast<IntegerAttr>().getInt()); 1543 } 1544 1545 //===----------------------------------------------------------------------===// 1546 // ExtractMapOp 1547 //===----------------------------------------------------------------------===// 1548 1549 void ExtractMapOp::build(OpBuilder &builder, OperationState &result, 1550 Value vector, ValueRange ids, 1551 ArrayRef<int64_t> multiplicity, 1552 AffineMap permutationMap) { 1553 assert(ids.size() == multiplicity.size() && 1554 ids.size() == permutationMap.getNumResults()); 1555 assert(permutationMap.isProjectedPermutation()); 1556 VectorType type = vector.getType().cast<VectorType>(); 1557 SmallVector<int64_t, 4> newShape(type.getShape().begin(), 1558 type.getShape().end()); 1559 for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) { 1560 AffineExpr expr = permutationMap.getResult(i); 1561 auto dim = expr.cast<AffineDimExpr>(); 1562 newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i]; 1563 } 1564 VectorType resultType = VectorType::get(newShape, type.getElementType()); 1565 ExtractMapOp::build(builder, result, resultType, vector, ids); 1566 } 1567 1568 LogicalResult ExtractMapOp::verify() { 1569 if (getSourceVectorType().getRank() != getResultType().getRank()) 1570 return emitOpError("expected source and destination vectors of same rank"); 1571 unsigned numId = 0; 1572 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) { 1573 if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) != 1574 0) 1575 return emitOpError("source vector dimensions must be a multiple of " 1576 "destination vector dimensions"); 1577 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) 1578 numId++; 1579 } 1580 if (numId != ids().size()) 1581 return emitOpError("expected number of ids must match the number of " 1582 "dimensions distributed"); 1583 return success(); 1584 } 1585 1586 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) { 1587 auto insert = vector().getDefiningOp<vector::InsertMapOp>(); 1588 if (insert == nullptr || getType() != insert.vector().getType() || 1589 ids() != insert.ids()) 1590 return {}; 1591 return insert.vector(); 1592 } 1593 1594 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) { 1595 assert(multiplicity.empty()); 1596 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) { 1597 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) 1598 multiplicity.push_back(getSourceVectorType().getDimSize(i) / 1599 getResultType().getDimSize(i)); 1600 } 1601 } 1602 1603 template 
<typename MapOp> 1604 AffineMap calculateImplicitMap(MapOp op) { 1605 SmallVector<AffineExpr, 4> perm; 1606 // Check which dimensions have a multiplicity greater than 1 and associate 1607 // them with the IDs in order. 1608 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) { 1609 if (op.getSourceVectorType().getDimSize(i) != 1610 op.getResultType().getDimSize(i)) 1611 perm.push_back(getAffineDimExpr(i, op.getContext())); 1612 } 1613 auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm, 1614 op.getContext()); 1615 return map; 1616 } 1617 1618 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); } 1619 1620 //===----------------------------------------------------------------------===// 1621 // FmaOp 1622 //===----------------------------------------------------------------------===// 1623 1624 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() { 1625 return llvm::to_vector<4>(getVectorType().getShape()); 1626 } 1627 1628 //===----------------------------------------------------------------------===// 1629 // BroadcastOp 1630 //===----------------------------------------------------------------------===// 1631 1632 BroadcastableToResult 1633 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, 1634 std::pair<int, int> *mismatchingDims) { 1635 // Broadcast scalar to vector of the same element type. 1636 if (srcType.isIntOrIndexOrFloat() && dstVectorType && 1637 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) 1638 return BroadcastableToResult::Success; 1639 // From now on, only vectors broadcast. 1640 VectorType srcVectorType = srcType.dyn_cast<VectorType>(); 1641 if (!srcVectorType) 1642 return BroadcastableToResult::SourceTypeNotAVector; 1643 1644 int64_t srcRank = srcVectorType.getRank(); 1645 int64_t dstRank = dstVectorType.getRank(); 1646 if (srcRank > dstRank) 1647 return BroadcastableToResult::SourceRankHigher; 1648 // Source has an exact match or singleton value for all trailing dimensions 1649 // (all leading dimensions are simply duplicated). 1650 int64_t lead = dstRank - srcRank; 1651 for (int64_t r = 0; r < srcRank; ++r) { 1652 int64_t srcDim = srcVectorType.getDimSize(r); 1653 int64_t dstDim = dstVectorType.getDimSize(lead + r); 1654 if (srcDim != 1 && srcDim != dstDim) { 1655 if (mismatchingDims) { 1656 mismatchingDims->first = srcDim; 1657 mismatchingDims->second = dstDim; 1658 } 1659 return BroadcastableToResult::DimensionMismatch; 1660 } 1661 } 1662 1663 return BroadcastableToResult::Success; 1664 } 1665 1666 LogicalResult BroadcastOp::verify() { 1667 std::pair<int, int> mismatchingDims; 1668 BroadcastableToResult res = 1669 isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims); 1670 if (res == BroadcastableToResult::Success) 1671 return success(); 1672 if (res == BroadcastableToResult::SourceRankHigher) 1673 return emitOpError("source rank higher than destination rank"); 1674 if (res == BroadcastableToResult::DimensionMismatch) 1675 return emitOpError("dimension mismatch (") 1676 << mismatchingDims.first << " vs. 
" << mismatchingDims.second << ")"; 1677 if (res == BroadcastableToResult::SourceTypeNotAVector) 1678 return emitOpError("source type is not a vector"); 1679 llvm_unreachable("unexpected vector.broadcast op error"); 1680 } 1681 1682 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 1683 if (getSourceType() == getVectorType()) 1684 return source(); 1685 if (!operands[0]) 1686 return {}; 1687 auto vectorType = getVectorType(); 1688 if (operands[0].getType().isIntOrIndexOrFloat()) 1689 return DenseElementsAttr::get(vectorType, operands[0]); 1690 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>()) 1691 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()); 1692 return {}; 1693 } 1694 1695 namespace { 1696 1697 // Fold broadcast1(broadcast2(x)) into broadcast1(x). 1698 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { 1699 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 1700 1701 LogicalResult matchAndRewrite(BroadcastOp broadcastOp, 1702 PatternRewriter &rewriter) const override { 1703 auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>(); 1704 if (!srcBroadcast) 1705 return failure(); 1706 rewriter.replaceOpWithNewOp<BroadcastOp>( 1707 broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source()); 1708 return success(); 1709 } 1710 }; 1711 } // namespace 1712 1713 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, 1714 MLIRContext *context) { 1715 // BroadcastToShapeCast is not a default canonicalization, it is opt-in by 1716 // calling `populateCastAwayVectorLeadingOneDimPatterns` 1717 results.add<BroadcastFolder>(context); 1718 } 1719 1720 //===----------------------------------------------------------------------===// 1721 // ShuffleOp 1722 //===----------------------------------------------------------------------===// 1723 1724 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, 1725 Value v2, ArrayRef<int64_t> mask) { 1726 result.addOperands({v1, v2}); 1727 auto maskAttr = getVectorSubscriptAttr(builder, mask); 1728 auto v1Type = v1.getType().cast<VectorType>(); 1729 auto shape = llvm::to_vector<4>(v1Type.getShape()); 1730 shape[0] = mask.size(); 1731 result.addTypes(VectorType::get(shape, v1Type.getElementType())); 1732 result.addAttribute(getMaskAttrStrName(), maskAttr); 1733 } 1734 1735 void ShuffleOp::print(OpAsmPrinter &p) { 1736 p << " " << v1() << ", " << v2() << " " << mask(); 1737 p.printOptionalAttrDict((*this)->getAttrs(), {ShuffleOp::getMaskAttrName()}); 1738 p << " : " << v1().getType() << ", " << v2().getType(); 1739 } 1740 1741 LogicalResult ShuffleOp::verify() { 1742 VectorType resultType = getVectorType(); 1743 VectorType v1Type = getV1VectorType(); 1744 VectorType v2Type = getV2VectorType(); 1745 // Verify ranks. 1746 int64_t resRank = resultType.getRank(); 1747 int64_t v1Rank = v1Type.getRank(); 1748 int64_t v2Rank = v2Type.getRank(); 1749 if (resRank != v1Rank || v1Rank != v2Rank) 1750 return emitOpError("rank mismatch"); 1751 // Verify all but leading dimension sizes. 1752 for (int64_t r = 1; r < v1Rank; ++r) { 1753 int64_t resDim = resultType.getDimSize(r); 1754 int64_t v1Dim = v1Type.getDimSize(r); 1755 int64_t v2Dim = v2Type.getDimSize(r); 1756 if (resDim != v1Dim || v1Dim != v2Dim) 1757 return emitOpError("dimension mismatch"); 1758 } 1759 // Verify mask length. 
1760 auto maskAttr = mask().getValue(); 1761 int64_t maskLength = maskAttr.size(); 1762 if (maskLength != resultType.getDimSize(0)) 1763 return emitOpError("mask length mismatch"); 1764 // Verify all indices. 1765 int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); 1766 for (const auto &en : llvm::enumerate(maskAttr)) { 1767 auto attr = en.value().dyn_cast<IntegerAttr>(); 1768 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) 1769 return emitOpError("mask index #") << (en.index() + 1) << " out of range"; 1770 } 1771 return success(); 1772 } 1773 1774 ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { 1775 OpAsmParser::OperandType v1, v2; 1776 Attribute attr; 1777 VectorType v1Type, v2Type; 1778 if (parser.parseOperand(v1) || parser.parseComma() || 1779 parser.parseOperand(v2) || 1780 parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(), 1781 result.attributes) || 1782 parser.parseOptionalAttrDict(result.attributes) || 1783 parser.parseColonType(v1Type) || parser.parseComma() || 1784 parser.parseType(v2Type) || 1785 parser.resolveOperand(v1, v1Type, result.operands) || 1786 parser.resolveOperand(v2, v2Type, result.operands)) 1787 return failure(); 1788 // Construct resulting type: leading dimension matches mask length, 1789 // all trailing dimensions match the operands. 1790 auto maskAttr = attr.dyn_cast<ArrayAttr>(); 1791 if (!maskAttr) 1792 return parser.emitError(parser.getNameLoc(), "missing mask attribute"); 1793 int64_t maskLength = maskAttr.size(); 1794 if (maskLength <= 0) 1795 return parser.emitError(parser.getNameLoc(), "invalid mask length"); 1796 int64_t v1Rank = v1Type.getRank(); 1797 SmallVector<int64_t, 4> shape; 1798 shape.reserve(v1Rank); 1799 shape.push_back(maskLength); 1800 for (int64_t r = 1; r < v1Rank; ++r) 1801 shape.push_back(v1Type.getDimSize(r)); 1802 VectorType resType = VectorType::get(shape, v1Type.getElementType()); 1803 parser.addTypeToList(resType, result.types); 1804 return success(); 1805 } 1806 1807 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) { 1808 Attribute lhs = operands.front(), rhs = operands.back(); 1809 if (!lhs || !rhs) 1810 return {}; 1811 1812 auto lhsType = lhs.getType().cast<VectorType>(); 1813 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr 1814 // manipulation. 
1815 if (lhsType.getRank() != 1) 1816 return {}; 1817 int64_t lhsSize = lhsType.getDimSize(0); 1818 1819 SmallVector<Attribute> results; 1820 auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>(); 1821 auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>(); 1822 for (const auto &index : this->mask().getAsValueRange<IntegerAttr>()) { 1823 int64_t i = index.getZExtValue(); 1824 if (i >= lhsSize) { 1825 results.push_back(rhsElements[i - lhsSize]); 1826 } else { 1827 results.push_back(lhsElements[i]); 1828 } 1829 } 1830 1831 return DenseElementsAttr::get(getVectorType(), results); 1832 } 1833 1834 //===----------------------------------------------------------------------===// 1835 // InsertElementOp 1836 //===----------------------------------------------------------------------===// 1837 1838 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 1839 Value source, Value dest) { 1840 result.addOperands({source, dest}); 1841 result.addTypes(dest.getType()); 1842 } 1843 1844 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 1845 Value source, Value dest, Value position) { 1846 result.addOperands({source, dest, position}); 1847 result.addTypes(dest.getType()); 1848 } 1849 1850 LogicalResult InsertElementOp::verify() { 1851 auto dstVectorType = getDestVectorType(); 1852 if (dstVectorType.getRank() == 0) { 1853 if (position()) 1854 return emitOpError("expected position to be empty with 0-D vector"); 1855 return success(); 1856 } 1857 if (dstVectorType.getRank() != 1) 1858 return emitOpError("unexpected >1 vector rank"); 1859 if (!position()) 1860 return emitOpError("expected position for 1-D vector"); 1861 return success(); 1862 } 1863 1864 //===----------------------------------------------------------------------===// 1865 // InsertOp 1866 //===----------------------------------------------------------------------===// 1867 1868 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1869 Value dest, ArrayRef<int64_t> position) { 1870 result.addOperands({source, dest}); 1871 auto positionAttr = getVectorSubscriptAttr(builder, position); 1872 result.addTypes(dest.getType()); 1873 result.addAttribute(getPositionAttrStrName(), positionAttr); 1874 } 1875 1876 // Convenience builder which assumes the values are constant indices. 
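// Each position value must be produced by an arith.constant index op; otherwise
// getDefiningOp<arith::ConstantIndexOp>() below yields a null op and this
// convenience builder cannot be used.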
1877 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1878 Value dest, ValueRange position) { 1879 SmallVector<int64_t, 4> positionConstants = 1880 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { 1881 return pos.getDefiningOp<arith::ConstantIndexOp>().value(); 1882 })); 1883 build(builder, result, source, dest, positionConstants); 1884 } 1885 1886 LogicalResult InsertOp::verify() { 1887 auto positionAttr = position().getValue(); 1888 auto destVectorType = getDestVectorType(); 1889 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank())) 1890 return emitOpError( 1891 "expected position attribute of rank smaller than dest vector rank"); 1892 auto srcVectorType = getSourceType().dyn_cast<VectorType>(); 1893 if (srcVectorType && 1894 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() != 1895 static_cast<unsigned>(destVectorType.getRank()))) 1896 return emitOpError("expected position attribute rank + source rank to " 1897 "match dest vector rank"); 1898 if (!srcVectorType && 1899 (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank()))) 1900 return emitOpError( 1901 "expected position attribute rank to match the dest vector rank"); 1902 for (const auto &en : llvm::enumerate(positionAttr)) { 1903 auto attr = en.value().dyn_cast<IntegerAttr>(); 1904 if (!attr || attr.getInt() < 0 || 1905 attr.getInt() >= destVectorType.getDimSize(en.index())) 1906 return emitOpError("expected position attribute #") 1907 << (en.index() + 1) 1908 << " to be a non-negative integer smaller than the corresponding " 1909 "dest vector dimension"; 1910 } 1911 return success(); 1912 } 1913 1914 namespace { 1915 1916 // If insertOp is only inserting unit dimensions it can be transformed to a 1917 // broadcast. 1918 class InsertToBroadcast final : public OpRewritePattern<InsertOp> { 1919 public: 1920 using OpRewritePattern<InsertOp>::OpRewritePattern; 1921 1922 LogicalResult matchAndRewrite(InsertOp insertOp, 1923 PatternRewriter &rewriter) const override { 1924 auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>(); 1925 if (!srcVecType || insertOp.getDestVectorType().getNumElements() != 1926 srcVecType.getNumElements()) 1927 return failure(); 1928 rewriter.replaceOpWithNewOp<BroadcastOp>( 1929 insertOp, insertOp.getDestVectorType(), insertOp.source()); 1930 return success(); 1931 } 1932 }; 1933 1934 } // namespace 1935 1936 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, 1937 MLIRContext *context) { 1938 results.add<InsertToBroadcast, BroadcastFolder>(context); 1939 } 1940 1941 // Eliminates insert operations that produce values identical to their source 1942 // value. This happens when the source and destination vectors have identical 1943 // sizes. 
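// E.g. inserting a vector<4xf32> into a vector<4xf32> with an empty position is
// a no-op, and the fold below returns the inserted value directly.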
1944 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) { 1945 if (position().empty()) 1946 return source(); 1947 return {}; 1948 } 1949 1950 //===----------------------------------------------------------------------===// 1951 // InsertMapOp 1952 //===----------------------------------------------------------------------===// 1953 1954 void InsertMapOp::build(OpBuilder &builder, OperationState &result, 1955 Value vector, Value dest, ValueRange ids) { 1956 InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids); 1957 } 1958 1959 LogicalResult InsertMapOp::verify() { 1960 if (getSourceVectorType().getRank() != getResultType().getRank()) 1961 return emitOpError("expected source and destination vectors of same rank"); 1962 unsigned numId = 0; 1963 for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) { 1964 if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) != 1965 0) 1966 return emitOpError( 1967 "destination vector size must be a multiple of source vector size"); 1968 if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i)) 1969 numId++; 1970 } 1971 if (numId != ids().size()) 1972 return emitOpError("expected number of ids must match the number of " 1973 "dimensions distributed"); 1974 return success(); 1975 } 1976 1977 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); } 1978 1979 //===----------------------------------------------------------------------===// 1980 // InsertStridedSliceOp 1981 //===----------------------------------------------------------------------===// 1982 1983 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result, 1984 Value source, Value dest, 1985 ArrayRef<int64_t> offsets, 1986 ArrayRef<int64_t> strides) { 1987 result.addOperands({source, dest}); 1988 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 1989 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 1990 result.addTypes(dest.getType()); 1991 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); 1992 result.addAttribute(getStridesAttrStrName(), stridesAttr); 1993 } 1994 1995 // TODO: Should be moved to Tablegen Confined attributes. 1996 template <typename OpType> 1997 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, 1998 ArrayAttr arrayAttr, 1999 ArrayRef<int64_t> shape, 2000 StringRef attrName) { 2001 if (arrayAttr.size() > shape.size()) 2002 return op.emitOpError("expected ") 2003 << attrName << " attribute of rank smaller than vector rank"; 2004 return success(); 2005 } 2006 2007 // Returns success if all integers in `arrayAttr` are confined to the interval 2008 // [min, max). If `halfOpen` is false, the admissible interval is [min, max] 2009 // instead. 2010 template <typename OpType> 2011 static LogicalResult 2012 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, 2013 int64_t max, StringRef attrName, 2014 bool halfOpen = true) { 2015 for (auto attr : arrayAttr) { 2016 auto val = attr.cast<IntegerAttr>().getInt(); 2017 auto upper = max; 2018 if (!halfOpen) 2019 upper += 1; 2020 if (val < min || val >= upper) 2021 return op.emitOpError("expected ") << attrName << " to be confined to [" 2022 << min << ", " << upper << ")"; 2023 } 2024 return success(); 2025 } 2026 2027 // Returns success if each integer in `arrayAttr` is confined to the 2028 // corresponding dimension of `shape`, i.e. lies in [min, shape[i]). If 2029 // `halfOpen` is false, the admissible interval is [min, shape[i]] instead. 2030 template <typename OpType> 2031 static LogicalResult 2032 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, 2033 ArrayRef<int64_t> shape, StringRef attrName, 2034 bool halfOpen = true, int64_t min = 0) { 2035 assert(arrayAttr.size() <= shape.size()); 2036 unsigned index = 0; 2037 for (auto it : llvm::zip(arrayAttr, shape)) { 2038 auto val = std::get<0>(it).cast<IntegerAttr>().getInt(); 2039 auto max = std::get<1>(it); 2040 if (!halfOpen) 2041 max += 1; 2042 if (val < min || val >= max) 2043 return op.emitOpError("expected ") 2044 << attrName << " dimension " << index << " to be confined to [" 2045 << min << ", " << max << ")"; 2046 ++index; 2047 } 2048 return success(); 2049 } 2050 2051 // Returns success if, for each dimension i, the sum of the corresponding 2052 // entries of `arrayAttr1` and `arrayAttr2` is confined to [0, shape[i]). If 2053 // `halfOpen` is false, the upper bound is inclusive. 2054 template <typename OpType> 2055 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape( 2056 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, 2057 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2, 2058 bool halfOpen = true, int64_t min = 1) { 2059 assert(arrayAttr1.size() <= shape.size()); 2060 assert(arrayAttr2.size() <= shape.size()); 2061 unsigned index = 0; 2062 for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) { 2063 auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt(); 2064 auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt(); 2065 auto max = std::get<2>(it); 2066 if (!halfOpen) 2067 max += 1; 2068 if (val1 + val2 < 0 || val1 + val2 >= max) 2069 return op.emitOpError("expected sum(") 2070 << attrName1 << ", " << attrName2 << ") dimension " << index 2071 << " to be confined to [" << min << ", " << max << ")"; 2072 ++index; 2073 } 2074 return success(); 2075 } 2076 2077 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values, 2078 MLIRContext *context) { 2079 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { 2080 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); 2081 }); 2082 return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); 2083 } 2084 2085 LogicalResult InsertStridedSliceOp::verify() { 2086 auto sourceVectorType = getSourceVectorType(); 2087 auto destVectorType = getDestVectorType(); 2088 auto offsets = offsetsAttr(); 2089 auto strides = stridesAttr(); 2090 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank())) 2091 return emitOpError( 2092 "expected offsets of same size as destination vector rank"); 2093 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank())) 2094 return emitOpError("expected strides of same size as source vector rank"); 2095 if (sourceVectorType.getRank() > destVectorType.getRank()) 2096 return emitOpError( 2097 "expected source rank to be smaller than destination rank"); 2098 2099 auto sourceShape = sourceVectorType.getShape(); 2100 auto destShape = destVectorType.getShape(); 2101 SmallVector<int64_t, 4> sourceShapeAsDestShape( 2102 destShape.size() - sourceShape.size(), 0); 2103 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end()); 2104 auto offName = InsertStridedSliceOp::getOffsetsAttrName(); 2105 auto stridesName = InsertStridedSliceOp::getStridesAttrName(); 2106 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape, 2107 offName)) || 2108 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, 2109
stridesName, 2110 /*halfOpen=*/false)) || 2111 failed(isSumOfIntegerArrayAttrConfinedToShape( 2112 *this, offsets, 2113 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape, 2114 offName, "source vector shape", 2115 /*halfOpen=*/false, /*min=*/1))) 2116 return failure(); 2117 2118 return success(); 2119 } 2120 2121 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2122 if (getSourceVectorType() == getDestVectorType()) 2123 return source(); 2124 return {}; 2125 } 2126 2127 //===----------------------------------------------------------------------===// 2128 // OuterProductOp 2129 //===----------------------------------------------------------------------===// 2130 2131 /// Build an op without mask, use the type of `acc` as the return type. 2132 void OuterProductOp::build(OpBuilder &builder, OperationState &result, 2133 Value lhs, Value rhs, Value acc) { 2134 result.addOperands({lhs, rhs, acc}); 2135 result.addTypes(acc.getType()); 2136 } 2137 2138 void OuterProductOp::print(OpAsmPrinter &p) { 2139 p << " " << lhs() << ", " << rhs(); 2140 if (!acc().empty()) { 2141 p << ", " << acc(); 2142 p.printOptionalAttrDict((*this)->getAttrs()); 2143 } 2144 p << " : " << lhs().getType() << ", " << rhs().getType(); 2145 } 2146 2147 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { 2148 SmallVector<OpAsmParser::OperandType, 3> operandsInfo; 2149 Type tLHS, tRHS; 2150 if (parser.parseOperandList(operandsInfo) || 2151 parser.parseOptionalAttrDict(result.attributes) || 2152 parser.parseColonType(tLHS) || parser.parseComma() || 2153 parser.parseType(tRHS)) 2154 return failure(); 2155 if (operandsInfo.size() < 2) 2156 return parser.emitError(parser.getNameLoc(), 2157 "expected at least 2 operands"); 2158 VectorType vLHS = tLHS.dyn_cast<VectorType>(); 2159 VectorType vRHS = tRHS.dyn_cast<VectorType>(); 2160 if (!vLHS) 2161 return parser.emitError(parser.getNameLoc(), 2162 "expected vector type for operand #1"); 2163 VectorType resType = 2164 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, 2165 vLHS.getElementType()) 2166 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); 2167 2168 if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { 2169 result.attributes.append( 2170 OuterProductOp::getKindAttrStrName(), 2171 CombiningKindAttr::get(OuterProductOp::getDefaultKind(), 2172 result.getContext())); 2173 } 2174 2175 return failure( 2176 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || 2177 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || 2178 (operandsInfo.size() > 2 && 2179 parser.resolveOperand(operandsInfo[2], resType, result.operands)) || 2180 parser.addTypeToList(resType, result.types)); 2181 } 2182 2183 LogicalResult OuterProductOp::verify() { 2184 Type tRHS = getOperandTypeRHS(); 2185 VectorType vLHS = getOperandVectorTypeLHS(), 2186 vRHS = tRHS.dyn_cast<VectorType>(), 2187 vACC = getOperandVectorTypeACC(), vRES = getVectorType(); 2188 2189 if (vLHS.getRank() != 1) 2190 return emitOpError("expected 1-d vector for operand #1"); 2191 2192 if (vRHS) { 2193 // Proper OUTER operation. 
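// E.g. (illustrative): an lhs of type vector<4xf32> and an rhs of type
// vector<8xf32> must produce a vector<4x8xf32> result; the optional acc
// operand, if present, must have that same result type.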
2194 if (vRHS.getRank() != 1) 2195 return emitOpError("expected 1-d vector for operand #2"); 2196 if (vRES.getRank() != 2) 2197 return emitOpError("expected 2-d vector result"); 2198 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2199 return emitOpError("expected #1 operand dim to match result dim #1"); 2200 if (vRHS.getDimSize(0) != vRES.getDimSize(1)) 2201 return emitOpError("expected #2 operand dim to match result dim #2"); 2202 } else { 2203 // An AXPY operation. 2204 if (vRES.getRank() != 1) 2205 return emitOpError("expected 1-d vector result"); 2206 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2207 return emitOpError("expected #1 operand dim to match result dim #1"); 2208 } 2209 2210 if (vACC && vACC != vRES) 2211 return emitOpError("expected operand #3 of same type as result type"); 2212 2213 // Verify supported combining kind. 2214 if (!isSupportedCombiningKind(kind(), vRES.getElementType())) 2215 return emitOpError("unsupported outerproduct type"); 2216 2217 return success(); 2218 } 2219 2220 //===----------------------------------------------------------------------===// 2221 // ReshapeOp 2222 //===----------------------------------------------------------------------===// 2223 2224 LogicalResult ReshapeOp::verify() { 2225 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. 2226 auto inputVectorType = getInputVectorType(); 2227 auto outputVectorType = getOutputVectorType(); 2228 int64_t inputShapeRank = getNumInputShapeSizes(); 2229 int64_t outputShapeRank = getNumOutputShapeSizes(); 2230 SmallVector<int64_t, 4> fixedVectorSizes; 2231 getFixedVectorSizes(fixedVectorSizes); 2232 int64_t numFixedVectorSizes = fixedVectorSizes.size(); 2233 2234 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) 2235 return emitError("invalid input shape for vector type ") 2236 << inputVectorType; 2237 2238 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) 2239 return emitError("invalid output shape for vector type ") 2240 << outputVectorType; 2241 2242 // Verify that the 'fixedVectorSizes' match an input/output vector shape 2243 // suffix. 2244 unsigned inputVectorRank = inputVectorType.getRank(); 2245 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2246 unsigned index = inputVectorRank - numFixedVectorSizes - i; 2247 if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) 2248 return emitError("fixed vector size must match input vector for dim ") 2249 << i; 2250 } 2251 2252 unsigned outputVectorRank = outputVectorType.getRank(); 2253 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2254 unsigned index = outputVectorRank - numFixedVectorSizes - i; 2255 if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) 2256 return emitError("fixed vector size must match output vector for dim ") 2257 << i; 2258 } 2259 2260 // If all shape operands are produced by constant ops, verify that product 2261 // of dimensions for input/output shape match. 
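// E.g. a constant input shape (2, 16) and output shape (4, 8) both describe
// 32 elements and are accepted; (2, 16) vs. (4, 4) would be rejected.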
2262 auto isDefByConstant = [](Value operand) { 2263 return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp()); 2264 }; 2265 if (llvm::all_of(input_shape(), isDefByConstant) && 2266 llvm::all_of(output_shape(), isDefByConstant)) { 2267 int64_t numInputElements = 1; 2268 for (auto operand : input_shape()) 2269 numInputElements *= 2270 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2271 int64_t numOutputElements = 1; 2272 for (auto operand : output_shape()) 2273 numOutputElements *= 2274 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2275 if (numInputElements != numOutputElements) 2276 return emitError("product of input and output shape sizes must match"); 2277 } 2278 return success(); 2279 } 2280 2281 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) { 2282 populateFromInt64AttrArray(fixed_vector_sizes(), results); 2283 } 2284 2285 //===----------------------------------------------------------------------===// 2286 // ExtractStridedSliceOp 2287 //===----------------------------------------------------------------------===// 2288 2289 // Inference works as follows: 2290 // 1. Add 'sizes' from prefix of dims in 'offsets'. 2291 // 2. Add sizes from 'vectorType' for remaining dims. 2292 static Type inferStridedSliceOpResultType(VectorType vectorType, 2293 ArrayAttr offsets, ArrayAttr sizes, 2294 ArrayAttr strides) { 2295 assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); 2296 SmallVector<int64_t, 4> shape; 2297 shape.reserve(vectorType.getRank()); 2298 unsigned idx = 0; 2299 for (unsigned e = offsets.size(); idx < e; ++idx) 2300 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt()); 2301 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) 2302 shape.push_back(vectorType.getShape()[idx]); 2303 2304 return VectorType::get(shape, vectorType.getElementType()); 2305 } 2306 2307 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, 2308 Value source, ArrayRef<int64_t> offsets, 2309 ArrayRef<int64_t> sizes, 2310 ArrayRef<int64_t> strides) { 2311 result.addOperands(source); 2312 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 2313 auto sizesAttr = getVectorSubscriptAttr(builder, sizes); 2314 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 2315 result.addTypes( 2316 inferStridedSliceOpResultType(source.getType().cast<VectorType>(), 2317 offsetsAttr, sizesAttr, stridesAttr)); 2318 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); 2319 result.addAttribute(getSizesAttrStrName(), sizesAttr); 2320 result.addAttribute(getStridesAttrStrName(), stridesAttr); 2321 } 2322 2323 LogicalResult ExtractStridedSliceOp::verify() { 2324 auto type = getVectorType(); 2325 auto offsets = offsetsAttr(); 2326 auto sizes = sizesAttr(); 2327 auto strides = stridesAttr(); 2328 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) 2329 return emitOpError("expected offsets, sizes and strides attributes of same size"); 2330 2331 auto shape = type.getShape(); 2332 auto offName = getOffsetsAttrName(); 2333 auto sizesName = getSizesAttrName(); 2334 auto stridesName = getStridesAttrName(); 2335 if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) || 2336 failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) || 2337 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape, 2338 stridesName)) || 2339 failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) || 2340 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName, 2341 /*halfOpen=*/false, 2342 /*min=*/1)) || 2343 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName, 2344 /*halfOpen=*/false)) || 2345 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape, 2346 offName, sizesName, 2347 /*halfOpen=*/false))) 2348 return failure(); 2349 2350 auto resultType = 2351 inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides); 2352 if (getResult().getType() != resultType) 2353 return emitOpError("expected result type to be ") << resultType; 2354 2355 return success(); 2356 } 2357 2358 // When the source of an ExtractStridedSliceOp comes from a chain of 2359 // InsertStridedSliceOp ops, try to use the source of one of those inserts if 2360 // we can detect that the extracted vector is a subset of one of the inserted vectors. 2361 static LogicalResult 2362 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { 2363 // Helper to extract integer out of ArrayAttr. 2364 auto getElement = [](ArrayAttr array, int idx) { 2365 return array[idx].cast<IntegerAttr>().getInt(); 2366 }; 2367 ArrayAttr extractOffsets = op.offsets(); 2368 ArrayAttr extractStrides = op.strides(); 2369 ArrayAttr extractSizes = op.sizes(); 2370 auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>(); 2371 while (insertOp) { 2372 if (op.getVectorType().getRank() != 2373 insertOp.getSourceVectorType().getRank()) 2374 return failure(); 2375 ArrayAttr insertOffsets = insertOp.offsets(); 2376 ArrayAttr insertStrides = insertOp.strides(); 2377 // If the rank of the extract is greater than the rank of the insert, we 2378 // are likely extracting a partial chunk of the vector inserted. 2379 if (extractOffsets.size() > insertOffsets.size()) 2380 return failure(); 2381 bool partialOverlap = false; 2382 bool disjoint = false; 2383 SmallVector<int64_t, 4> offsetDiffs; 2384 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) { 2385 if (getElement(extractStrides, dim) != getElement(insertStrides, dim)) 2386 return failure(); 2387 int64_t start = getElement(insertOffsets, dim); 2388 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim); 2389 int64_t offset = getElement(extractOffsets, dim); 2390 int64_t size = getElement(extractSizes, dim); 2391 // Check if the start of the extract offset is in the interval inserted. 2392 if (start <= offset && offset < end) { 2393 // If the extract interval overlaps but is not fully included we may 2394 // have a partial overlap that will prevent any folding. 2395 if (offset + size > end) 2396 partialOverlap = true; 2397 offsetDiffs.push_back(offset - start); 2398 continue; 2399 } 2400 disjoint = true; 2401 break; 2402 } 2403 // The extracted chunk is a subset of the inserted chunk. 2404 if (!disjoint && !partialOverlap) { 2405 op.setOperand(insertOp.source()); 2406 // OpBuilder is only used as a helper to build an I64ArrayAttr. 2407 OpBuilder b(op.getContext()); 2408 op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(), 2409 b.getI64ArrayAttr(offsetDiffs)); 2410 return success(); 2411 } 2412 // If the chunk extracted is disjoint from the chunk inserted, keep looking 2413 // in the insert chain. 2414 if (disjoint) 2415 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>(); 2416 else { 2417 // The extracted vector partially overlaps the inserted vector; we cannot 2418 // fold.
2419 return failure(); 2420 } 2421 } 2422 return failure(); 2423 } 2424 2425 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2426 if (getVectorType() == getResult().getType()) 2427 return vector(); 2428 if (succeeded(foldExtractStridedOpFromInsertChain(*this))) 2429 return getResult(); 2430 return {}; 2431 } 2432 2433 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) { 2434 populateFromInt64AttrArray(offsets(), results); 2435 } 2436 2437 namespace { 2438 2439 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to 2440 // ConstantMaskOp. 2441 class StridedSliceConstantMaskFolder final 2442 : public OpRewritePattern<ExtractStridedSliceOp> { 2443 public: 2444 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2445 2446 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2447 PatternRewriter &rewriter) const override { 2448 // Return if 'extractStridedSliceOp' operand is not defined by a 2449 // ConstantMaskOp. 2450 auto *defOp = extractStridedSliceOp.vector().getDefiningOp(); 2451 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp); 2452 if (!constantMaskOp) 2453 return failure(); 2454 // Return if 'extractStridedSliceOp' has non-unit strides. 2455 if (extractStridedSliceOp.hasNonUnitStrides()) 2456 return failure(); 2457 // Gather constant mask dimension sizes. 2458 SmallVector<int64_t, 4> maskDimSizes; 2459 populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); 2460 // Gather strided slice offsets and sizes. 2461 SmallVector<int64_t, 4> sliceOffsets; 2462 populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets); 2463 SmallVector<int64_t, 4> sliceSizes; 2464 populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes); 2465 2466 // Compute slice of vector mask region. 2467 SmallVector<int64_t, 4> sliceMaskDimSizes; 2468 assert(sliceOffsets.size() == maskDimSizes.size()); 2469 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { 2470 int64_t maskDimSize = std::get<0>(it); 2471 int64_t sliceOffset = std::get<1>(it); 2472 int64_t sliceSize = std::get<2>(it); 2473 int64_t sliceMaskDimSize = std::max( 2474 static_cast<int64_t>(0), 2475 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset); 2476 sliceMaskDimSizes.push_back(sliceMaskDimSize); 2477 } 2478 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked 2479 // region is a conjunction of mask dim intervals). 2480 if (llvm::is_contained(sliceMaskDimSizes, 0)) 2481 sliceMaskDimSizes.assign(maskDimSizes.size(), 0); 2482 2483 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask 2484 // region. 2485 rewriter.replaceOpWithNewOp<ConstantMaskOp>( 2486 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), 2487 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); 2488 return success(); 2489 } 2490 }; 2491 2492 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. 2493 class StridedSliceConstantFolder final 2494 : public OpRewritePattern<ExtractStridedSliceOp> { 2495 public: 2496 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2497 2498 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2499 PatternRewriter &rewriter) const override { 2500 // Return if 'extractStridedSliceOp' operand is not defined by a 2501 // ConstantOp. 
2502 auto constantOp = 2503 extractStridedSliceOp.vector().getDefiningOp<arith::ConstantOp>(); 2504 if (!constantOp) 2505 return failure(); 2506 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 2507 if (!dense) 2508 return failure(); 2509 auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), 2510 dense.getSplatValue<Attribute>()); 2511 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp, 2512 newAttr); 2513 return success(); 2514 } 2515 }; 2516 2517 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to 2518 // BroadcastOp(ExtractStridedSliceOp). 2519 class StridedSliceBroadcast final 2520 : public OpRewritePattern<ExtractStridedSliceOp> { 2521 public: 2522 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2523 2524 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 2525 PatternRewriter &rewriter) const override { 2526 auto broadcast = op.vector().getDefiningOp<BroadcastOp>(); 2527 if (!broadcast) 2528 return failure(); 2529 auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>(); 2530 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0; 2531 auto dstVecType = op.getType().cast<VectorType>(); 2532 unsigned dstRank = dstVecType.getRank(); 2533 unsigned rankDiff = dstRank - srcRank; 2534 // Check if the innermost dimensions of the broadcast source match the 2535 // corresponding dimensions of the extract destination. If so, we can just 2536 // use a broadcast, as those dimensions are untouched. 2537 bool lowerDimMatch = true; 2538 for (unsigned i = 0; i < srcRank; i++) { 2539 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) { 2540 lowerDimMatch = false; 2541 break; 2542 } 2543 } 2544 Value source = broadcast.source(); 2545 if (!lowerDimMatch) { 2546 // The inner dimensions don't match, so we need to extract from the 2547 // source of the original broadcast and then broadcast the extracted value. 2548 source = rewriter.create<ExtractStridedSliceOp>( 2549 op->getLoc(), source, 2550 getI64SubArray(op.offsets(), /* dropFront=*/rankDiff), 2551 getI64SubArray(op.sizes(), /* dropFront=*/rankDiff), 2552 getI64SubArray(op.strides(), /* dropFront=*/rankDiff)); 2553 } 2554 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source); 2555 return success(); 2556 } 2557 }; 2558 2559 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp. 2560 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { 2561 public: 2562 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2563 2564 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 2565 PatternRewriter &rewriter) const override { 2566 auto splat = op.vector().getDefiningOp<SplatOp>(); 2567 if (!splat) 2568 return failure(); 2569 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input()); 2570 return success(); 2571 } 2572 }; 2573 2574 } // namespace 2575 2576 void ExtractStridedSliceOp::getCanonicalizationPatterns( 2577 RewritePatternSet &results, MLIRContext *context) { 2578 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) -> 2579 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2580 results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder, 2581 StridedSliceBroadcast, StridedSliceSplat>(context); 2582 } 2583 2584 //===----------------------------------------------------------------------===// 2585 // TransferReadOp 2586 //===----------------------------------------------------------------------===// 2587 2588 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). 2589 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2590 VectorType vectorType, Value source, 2591 ValueRange indices, AffineMapAttr permutationMapAttr, 2592 /*optional*/ ArrayAttr inBoundsAttr) { 2593 Type elemType = source.getType().cast<ShapedType>().getElementType(); 2594 Value padding = builder.create<arith::ConstantOp>( 2595 result.location, elemType, builder.getZeroAttr(elemType)); 2596 build(builder, result, vectorType, source, indices, permutationMapAttr, 2597 padding, /*mask=*/Value(), inBoundsAttr); 2598 } 2599 2600 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs). 2601 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2602 VectorType vectorType, Value source, 2603 ValueRange indices, AffineMap permutationMap, 2604 Optional<ArrayRef<bool>> inBounds) { 2605 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 2606 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 2607 ? builder.getBoolArrayAttr(inBounds.getValue()) 2608 : ArrayAttr(); 2609 build(builder, result, vectorType, source, indices, permutationMapAttr, 2610 inBoundsAttr); 2611 } 2612 2613 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. 2614 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2615 VectorType vectorType, Value source, 2616 ValueRange indices, Value padding, 2617 Optional<ArrayRef<bool>> inBounds) { 2618 AffineMap permutationMap = getTransferMinorIdentityMap( 2619 source.getType().cast<ShapedType>(), vectorType); 2620 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 2621 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 2622 ? builder.getBoolArrayAttr(inBounds.getValue()) 2623 : ArrayAttr(); 2624 build(builder, result, vectorType, source, indices, permutationMapAttr, 2625 padding, 2626 /*mask=*/Value(), inBoundsAttr); 2627 } 2628 2629 /// 4. Builder that sets padding to zero and permutation map to 2630 /// 'getMinorIdentityMap'.
2631 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2632 VectorType vectorType, Value source, 2633 ValueRange indices, 2634 Optional<ArrayRef<bool>> inBounds) { 2635 Type elemType = source.getType().cast<ShapedType>().getElementType(); 2636 Value padding = builder.create<arith::ConstantOp>( 2637 result.location, elemType, builder.getZeroAttr(elemType)); 2638 build(builder, result, vectorType, source, indices, padding, inBounds); 2639 } 2640 2641 template <typename EmitFun> 2642 static LogicalResult verifyPermutationMap(AffineMap permutationMap, 2643 EmitFun emitOpError) { 2644 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); 2645 for (auto expr : permutationMap.getResults()) { 2646 auto dim = expr.dyn_cast<AffineDimExpr>(); 2647 auto zero = expr.dyn_cast<AffineConstantExpr>(); 2648 if (zero) { 2649 if (zero.getValue() != 0) { 2650 return emitOpError( 2651 "requires a projected permutation_map (at most one dim or the zero " 2652 "constant can appear in each result)"); 2653 } 2654 continue; 2655 } 2656 if (!dim) { 2657 return emitOpError("requires a projected permutation_map (at most one " 2658 "dim or the zero constant can appear in each result)"); 2659 } 2660 if (seen[dim.getPosition()]) { 2661 return emitOpError( 2662 "requires a permutation_map that is a permutation (found one dim " 2663 "used more than once)"); 2664 } 2665 seen[dim.getPosition()] = true; 2666 } 2667 return success(); 2668 } 2669 2670 static LogicalResult 2671 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, 2672 VectorType vectorType, VectorType maskType, 2673 AffineMap permutationMap, ArrayAttr inBounds) { 2674 if (op->hasAttr("masked")) { 2675 return op->emitOpError("masked attribute has been removed. " 2676 "Use in_bounds instead."); 2677 } 2678 2679 if (!shapedType.isa<MemRefType, RankedTensorType>()) 2680 return op->emitOpError( 2681 "requires source to be a memref or ranked tensor type"); 2682 2683 auto elementType = shapedType.getElementType(); 2684 DataLayout dataLayout = DataLayout::closest(op); 2685 if (auto vectorElementType = elementType.dyn_cast<VectorType>()) { 2686 // Memref or tensor has vector element type. 2687 unsigned sourceVecSize = 2688 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * 2689 vectorElementType.getShape().back(); 2690 unsigned resultVecSize = 2691 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * 2692 vectorType.getShape().back(); 2693 if (resultVecSize % sourceVecSize != 0) 2694 return op->emitOpError( 2695 "requires the bitwidth of the minor 1-D vector to be an integral " 2696 "multiple of the bitwidth of the minor 1-D vector of the source"); 2697 2698 unsigned sourceVecEltRank = vectorElementType.getRank(); 2699 unsigned resultVecRank = vectorType.getRank(); 2700 if (sourceVecEltRank > resultVecRank) 2701 return op->emitOpError( 2702 "requires source vector element and vector result ranks to match."); 2703 unsigned rankOffset = resultVecRank - sourceVecEltRank; 2704 // Check that permutation map results match 'rankOffset' of vector type. 2705 if (permutationMap.getNumResults() != rankOffset) 2706 return op->emitOpError("requires a permutation_map with result dims of " 2707 "the same rank as the vector type"); 2708 2709 if (maskType) 2710 return op->emitOpError("does not support masks with vector element type"); 2711 } else { 2712 // Memref or tensor has scalar element type. 2713 unsigned minorSize = 2714 vectorType.getRank() == 0 ? 
1 : vectorType.getShape().back(); 2715 unsigned resultVecSize = 2716 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize; 2717 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0) 2718 return op->emitOpError( 2719 "requires the bitwidth of the minor 1-D vector to be an integral " 2720 "multiple of the bitwidth of the source element type"); 2721 2722 // Check that permutation map results match rank of vector type. 2723 if (permutationMap.getNumResults() != vectorType.getRank()) 2724 return op->emitOpError("requires a permutation_map with result dims of " 2725 "the same rank as the vector type"); 2726 2727 VectorType expectedMaskType = 2728 vector::detail::transferMaskType(vectorType, permutationMap); 2729 if (maskType && expectedMaskType != maskType) 2730 return op->emitOpError("expects mask type consistent with permutation " 2731 "map: ") 2732 << maskType; 2733 } 2734 2735 if (permutationMap.getNumSymbols() != 0) 2736 return op->emitOpError("requires permutation_map without symbols"); 2737 2738 if (permutationMap.getNumInputs() != shapedType.getRank()) 2739 return op->emitOpError("requires a permutation_map with input dims of the " 2740 "same rank as the source type"); 2741 2742 if (inBounds) { 2743 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size())) 2744 return op->emitOpError("expects the optional in_bounds attr of same rank " 2745 "as permutation_map results: ") 2746 << AffineMapAttr::get(permutationMap) 2747 << " vs inBounds of size: " << inBounds.size(); 2748 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) 2749 if (permutationMap.getResult(i).isa<AffineConstantExpr>() && 2750 !inBounds.getValue()[i].cast<BoolAttr>().getValue()) 2751 return op->emitOpError("requires broadcast dimensions to be in-bounds"); 2752 } 2753 2754 return success(); 2755 } 2756 2757 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { 2758 SmallVector<StringRef, 3> elidedAttrs; 2759 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); 2760 if (op.permutation_map().isMinorIdentity()) 2761 elidedAttrs.push_back(op.getPermutationMapAttrStrName()); 2762 bool elideInBounds = true; 2763 if (auto inBounds = op.in_bounds()) { 2764 for (auto attr : *inBounds) { 2765 if (attr.template cast<BoolAttr>().getValue()) { 2766 elideInBounds = false; 2767 break; 2768 } 2769 } 2770 } 2771 if (elideInBounds) 2772 elidedAttrs.push_back(op.getInBoundsAttrStrName()); 2773 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); 2774 } 2775 2776 void TransferReadOp::print(OpAsmPrinter &p) { 2777 p << " " << source() << "[" << indices() << "], " << padding(); 2778 if (mask()) 2779 p << ", " << mask(); 2780 printTransferAttrs(p, *this); 2781 p << " : " << getShapedType() << ", " << getVectorType(); 2782 } 2783 2784 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { 2785 auto &builder = parser.getBuilder(); 2786 SMLoc typesLoc; 2787 OpAsmParser::OperandType sourceInfo; 2788 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 2789 OpAsmParser::OperandType paddingInfo; 2790 SmallVector<Type, 2> types; 2791 OpAsmParser::OperandType maskInfo; 2792 // Parsing with support for paddingValue. 
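// The expected textual form is roughly (mask optional, types illustrative):
//   vector.transfer_read %src[%i, %j], %pad, %mask {in_bounds = [...]}
//     : memref<?x?xf32>, vector<4xf32>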
2793 if (parser.parseOperand(sourceInfo) || 2794 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 2795 parser.parseComma() || parser.parseOperand(paddingInfo)) 2796 return failure(); 2797 ParseResult hasMask = parser.parseOptionalComma(); 2798 if (hasMask.succeeded()) { 2799 parser.parseOperand(maskInfo); 2800 } 2801 if (parser.parseOptionalAttrDict(result.attributes) || 2802 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 2803 return failure(); 2804 if (types.size() != 2) 2805 return parser.emitError(typesLoc, "requires two types"); 2806 auto indexType = builder.getIndexType(); 2807 auto shapedType = types[0].dyn_cast<ShapedType>(); 2808 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 2809 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 2810 VectorType vectorType = types[1].dyn_cast<VectorType>(); 2811 if (!vectorType) 2812 return parser.emitError(typesLoc, "requires vector type"); 2813 auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName(); 2814 Attribute mapAttr = result.attributes.get(permutationAttrName); 2815 if (!mapAttr) { 2816 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 2817 // Update `mapAttr` that is used later to determine mask type. 2818 mapAttr = AffineMapAttr::get(permMap); 2819 result.attributes.set(permutationAttrName, mapAttr); 2820 } 2821 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || 2822 parser.resolveOperands(indexInfo, indexType, result.operands) || 2823 parser.resolveOperand(paddingInfo, shapedType.getElementType(), 2824 result.operands)) 2825 return failure(); 2826 if (hasMask.succeeded()) { 2827 if (shapedType.getElementType().dyn_cast<VectorType>()) 2828 return parser.emitError( 2829 maskInfo.location, "does not support masks with vector element type"); 2830 auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue(); 2831 // Instead of adding the mask type as an op type, compute it based on the 2832 // vector type and the permutation map (to keep the type signature small). 2833 auto maskType = mlir::vector::detail::transferMaskType(vectorType, map); 2834 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 2835 return failure(); 2836 } 2837 result.addAttribute( 2838 TransferReadOp::getOperandSegmentSizeAttr(), 2839 builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1, 2840 static_cast<int32_t>(hasMask.succeeded())})); 2841 return parser.addTypeToList(vectorType, result.types); 2842 } 2843 2844 LogicalResult TransferReadOp::verify() { 2845 // Consistency of elemental types in source and vector. 2846 ShapedType shapedType = getShapedType(); 2847 VectorType vectorType = getVectorType(); 2848 VectorType maskType = getMaskType(); 2849 auto paddingType = padding().getType(); 2850 auto permutationMap = permutation_map(); 2851 auto sourceElementType = shapedType.getElementType(); 2852 2853 if (static_cast<int64_t>(indices().size()) != shapedType.getRank()) 2854 return emitOpError("requires ") << shapedType.getRank() << " indices"; 2855 2856 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 2857 shapedType, vectorType, maskType, permutationMap, 2858 in_bounds() ? *in_bounds() : ArrayAttr()))) 2859 return failure(); 2860 2861 if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) { 2862 // Source has vector element type. 2863 // Check that 'sourceVectorElementType' and 'paddingType' types match. 
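// E.g. for a source of type memref<8xvector<4xf32>>, the padding operand must
// itself be a vector<4xf32> value.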
2864 if (sourceVectorElementType != paddingType) 2865 return emitOpError( 2866 "requires source element type and padding type to match."); 2867 2868 } else { 2869 // Check that 'paddingType' is valid to store in a vector type. 2870 if (!VectorType::isValidElementType(paddingType)) 2871 return emitOpError("requires valid padding vector elemental type"); 2872 2873 // Check that padding type and vector element types match. 2874 if (paddingType != sourceElementType) 2875 return emitOpError( 2876 "requires formal padding and source of the same elemental type"); 2877 } 2878 2879 return verifyPermutationMap(permutationMap, 2880 [&](Twine t) { return emitOpError(t); }); 2881 } 2882 2883 /// This is a common class used for patterns of the form 2884 /// ``` 2885 /// someop(memrefcast) -> someop 2886 /// ``` 2887 /// It folds the source of the memref.cast into the root operation directly. 2888 static LogicalResult foldMemRefCast(Operation *op) { 2889 bool folded = false; 2890 for (OpOperand &operand : op->getOpOperands()) { 2891 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 2892 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 2893 operand.set(castOp.getOperand()); 2894 folded = true; 2895 } 2896 } 2897 return success(folded); 2898 } 2899 2900 static LogicalResult foldTensorCast(Operation *op) { 2901 bool folded = false; 2902 for (OpOperand &operand : op->getOpOperands()) { 2903 auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); 2904 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { 2905 operand.set(castOp.getOperand()); 2906 folded = true; 2907 } 2908 } 2909 return success(folded); 2910 } 2911 2912 template <typename TransferOp> 2913 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { 2914 // TODO: support more aggressive createOrFold on: 2915 // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)` 2916 if (op.getShapedType().isDynamicDim(indicesIdx)) 2917 return false; 2918 Value index = op.indices()[indicesIdx]; 2919 auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>(); 2920 if (!cstOp) 2921 return false; 2922 2923 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); 2924 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); 2925 2926 return cstOp.value() + vectorSize <= sourceSize; 2927 } 2928 2929 template <typename TransferOp> 2930 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { 2931 // TODO: support 0-d corner case. 2932 // TODO: Be less conservative. 2933 if (op.getTransferRank() == 0) 2934 return failure(); 2935 AffineMap permutationMap = op.permutation_map(); 2936 bool changed = false; 2937 SmallVector<bool, 4> newInBounds; 2938 newInBounds.reserve(op.getTransferRank()); 2939 for (unsigned i = 0; i < op.getTransferRank(); ++i) { 2940 // Already marked as in-bounds, nothing to see here. 2941 if (op.isDimInBounds(i)) { 2942 newInBounds.push_back(true); 2943 continue; 2944 } 2945 // Currently out-of-bounds, check whether we can statically determine it is 2946 // inBounds. 2947 auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>(); 2948 assert(dimExpr && "Broadcast dims must be in-bounds"); 2949 auto inBounds = 2950 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition()); 2951 newInBounds.push_back(inBounds); 2952 // We commit the pattern if it is "more inbounds". 2953 changed |= inBounds; 2954 } 2955 if (!changed) 2956 return failure(); 2957 // OpBuilder is only used as a helper to build an I64ArrayAttr. 
2958 OpBuilder b(op.getContext()); 2959 op->setAttr(TransferOp::getInBoundsAttrStrName(), 2960 b.getBoolArrayAttr(newInBounds)); 2961 return success(); 2962 } 2963 2964 /// ``` 2965 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} 2966 /// : vector<1x4xf32>, tensor<4x4xf32> 2967 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} 2968 /// : tensor<4x4xf32>, vector<1x4xf32> 2969 /// ``` 2970 /// -> Folds into 2971 /// ``` 2972 /// %v0 2973 /// ``` 2974 static Value foldRAW(TransferReadOp readOp) { 2975 if (!readOp.getShapedType().isa<RankedTensorType>()) 2976 return {}; 2977 auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>(); 2978 while (defWrite) { 2979 if (checkSameValueRAW(defWrite, readOp)) 2980 return defWrite.vector(); 2981 if (!isDisjointTransferIndices( 2982 cast<VectorTransferOpInterface>(defWrite.getOperation()), 2983 cast<VectorTransferOpInterface>(readOp.getOperation()))) 2984 break; 2985 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>(); 2986 } 2987 return {}; 2988 } 2989 2990 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) { 2991 if (Value vec = foldRAW(*this)) 2992 return vec; 2993 /// transfer_read(memrefcast) -> transfer_read 2994 if (succeeded(foldTransferInBoundsAttribute(*this))) 2995 return getResult(); 2996 if (succeeded(foldMemRefCast(*this))) 2997 return getResult(); 2998 if (succeeded(foldTensorCast(*this))) 2999 return getResult(); 3000 return OpFoldResult(); 3001 } 3002 3003 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { 3004 return llvm::to_vector<4>(getVectorType().getShape()); 3005 } 3006 3007 void TransferReadOp::getEffects( 3008 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 3009 &effects) { 3010 if (getShapedType().isa<MemRefType>()) 3011 effects.emplace_back(MemoryEffects::Read::get(), source(), 3012 SideEffects::DefaultResource::get()); 3013 } 3014 3015 namespace { 3016 /// Fold transfer_reads of a tensor.extract_slice op. E.g.: 3017 /// 3018 /// ``` 3019 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1] 3020 /// : tensor<?x?xf32> to tensor<?x?xf32> 3021 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]} 3022 /// : tensor<?x?xf32>, vector<4x5xf32> 3023 /// ``` 3024 /// is rewritten to: 3025 /// ``` 3026 /// %p0 = arith.addi %a, %e : index 3027 /// %p1 = arith.addi %b, %f : index 3028 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} 3029 /// : tensor<?x?xf32>, vector<4x5xf32> 3030 /// ``` 3031 struct FoldExtractSliceIntoTransferRead 3032 : public OpRewritePattern<TransferReadOp> { 3033 public: 3034 using OpRewritePattern<TransferReadOp>::OpRewritePattern; 3035 3036 LogicalResult matchAndRewrite(TransferReadOp xferOp, 3037 PatternRewriter &rewriter) const override { 3038 // TODO: support 0-d corner case. 3039 if (xferOp.getTransferRank() == 0) 3040 return failure(); 3041 if (xferOp.hasOutOfBoundsDim()) 3042 return failure(); 3043 if (!xferOp.permutation_map().isIdentity()) 3044 return failure(); 3045 if (xferOp.mask()) 3046 return failure(); 3047 auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>(); 3048 if (!extractOp) 3049 return failure(); 3050 if (!extractOp.hasUnitStride()) 3051 return failure(); 3052 3053 // Bail on illegal rank-reduction: we need to check that the rank-reduced 3054 // dims are exactly the leading dims. I.e. 
the following is illegal: 3055 // ``` 3056 // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] : 3057 // tensor<2x1x4xf32> to tensor<2x4xf32> 3058 // %1 = vector.transfer_read %0[0,0], %cst : 3059 // tensor<2x4xf32>, vector<2x4xf32> 3060 // ``` 3061 // 3062 // Cannot fold into: 3063 // ``` 3064 // %0 = vector.transfer_read %t[0,0,0], %cst : 3065 // tensor<2x1x4xf32>, vector<2x4xf32> 3066 // ``` 3067 // For this, check the trailing `vectorRank` dims of the extract_slice 3068 // result tensor match the trailing dims of the inferred result tensor. 3069 int64_t rankReduced = 3070 extractOp.getSourceType().getRank() - extractOp.getType().getRank(); 3071 int64_t vectorRank = xferOp.getVectorType().getRank(); 3072 RankedTensorType inferredDestTensorType = 3073 tensor::ExtractSliceOp::inferResultType( 3074 extractOp.getSourceType(), extractOp.getMixedOffsets(), 3075 extractOp.getMixedSizes(), extractOp.getMixedStrides()); 3076 auto actualDestTensorShape = extractOp.getType().getShape(); 3077 if (rankReduced > 0 && 3078 actualDestTensorShape.take_back(vectorRank) != 3079 inferredDestTensorType.getShape().take_back(vectorRank)) 3080 return failure(); 3081 3082 SmallVector<Value> newIndices; 3083 // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced 3084 // indices first. 3085 for (int64_t i = 0; i < rankReduced; ++i) { 3086 OpFoldResult offset = extractOp.getMixedOffsets()[i]; 3087 newIndices.push_back(getValueOrCreateConstantIndexOp( 3088 rewriter, extractOp.getLoc(), offset)); 3089 } 3090 for (const auto &it : llvm::enumerate(xferOp.indices())) { 3091 OpFoldResult offset = 3092 extractOp.getMixedOffsets()[it.index() + rankReduced]; 3093 newIndices.push_back(rewriter.create<arith::AddIOp>( 3094 xferOp->getLoc(), it.value(), 3095 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), 3096 offset))); 3097 } 3098 SmallVector<bool> inBounds(xferOp.getTransferRank(), true); 3099 rewriter.replaceOpWithNewOp<TransferReadOp>( 3100 xferOp, xferOp.getVectorType(), extractOp.source(), newIndices, 3101 xferOp.padding(), ArrayRef<bool>{inBounds}); 3102 3103 return success(); 3104 } 3105 }; 3106 } // namespace 3107 3108 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3109 MLIRContext *context) { 3110 results.add<FoldExtractSliceIntoTransferRead>(context); 3111 } 3112 3113 //===----------------------------------------------------------------------===// 3114 // TransferWriteOp 3115 //===----------------------------------------------------------------------===// 3116 3117 /// 1. Builder with type inference. 3118 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3119 Value vector, Value dest, ValueRange indices, 3120 AffineMapAttr permutationMapAttr, 3121 /*optional*/ Value mask, 3122 /*optional*/ ArrayAttr inBoundsAttr) { 3123 Type resultType = dest.getType().dyn_cast<RankedTensorType>(); 3124 build(builder, result, resultType, vector, dest, indices, permutationMapAttr, 3125 mask, inBoundsAttr); 3126 } 3127 3128 /// 2. Builder with type inference that sets an empty mask (variant with attrs). 3129 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3130 Value vector, Value dest, ValueRange indices, 3131 AffineMapAttr permutationMapAttr, 3132 /*optional*/ ArrayAttr inBoundsAttr) { 3133 build(builder, result, vector, dest, indices, permutationMapAttr, 3134 /*mask=*/Value(), inBoundsAttr); 3135 } 3136 3137 /// 3. 
Builder with type inference that sets an empty mask (variant without 3138 /// attrs) 3139 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3140 Value vector, Value dest, ValueRange indices, 3141 AffineMap permutationMap, 3142 Optional<ArrayRef<bool>> inBounds) { 3143 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 3144 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 3145 ? builder.getBoolArrayAttr(inBounds.getValue()) 3146 : ArrayAttr(); 3147 build(builder, result, vector, dest, indices, permutationMapAttr, 3148 /*mask=*/Value(), inBoundsAttr); 3149 } 3150 3151 /// 4. Builder with type inference that sets an empty mask and sets permutation 3152 /// map to 'getMinorIdentityMap'. 3153 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3154 Value vector, Value dest, ValueRange indices, 3155 Optional<ArrayRef<bool>> inBounds) { 3156 auto vectorType = vector.getType().cast<VectorType>(); 3157 AffineMap permutationMap = getTransferMinorIdentityMap( 3158 dest.getType().cast<ShapedType>(), vectorType); 3159 build(builder, result, vector, dest, indices, permutationMap, inBounds); 3160 } 3161 3162 ParseResult TransferWriteOp::parse(OpAsmParser &parser, 3163 OperationState &result) { 3164 auto &builder = parser.getBuilder(); 3165 SMLoc typesLoc; 3166 OpAsmParser::OperandType vectorInfo, sourceInfo; 3167 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 3168 SmallVector<Type, 2> types; 3169 OpAsmParser::OperandType maskInfo; 3170 if (parser.parseOperand(vectorInfo) || parser.parseComma() || 3171 parser.parseOperand(sourceInfo) || 3172 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) 3173 return failure(); 3174 ParseResult hasMask = parser.parseOptionalComma(); 3175 if (hasMask.succeeded() && parser.parseOperand(maskInfo)) 3176 return failure(); 3177 if (parser.parseOptionalAttrDict(result.attributes) || 3178 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 3179 return failure(); 3180 if (types.size() != 2) 3181 return parser.emitError(typesLoc, "requires two types"); 3182 auto indexType = builder.getIndexType(); 3183 VectorType vectorType = types[0].dyn_cast<VectorType>(); 3184 if (!vectorType) 3185 return parser.emitError(typesLoc, "requires vector type"); 3186 ShapedType shapedType = types[1].dyn_cast<ShapedType>(); 3187 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 3188 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 3189 auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName(); 3190 auto attr = result.attributes.get(permutationAttrName); 3191 if (!attr) { 3192 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 3193 result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); 3194 } 3195 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || 3196 parser.resolveOperand(sourceInfo, shapedType, result.operands) || 3197 parser.resolveOperands(indexInfo, indexType, result.operands)) 3198 return failure(); 3199 if (hasMask.succeeded()) { 3200 if (shapedType.getElementType().dyn_cast<VectorType>()) 3201 return parser.emitError( 3202 maskInfo.location, "does not support masks with vector element type"); 3203 auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type()); 3204 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 3205 return failure(); 3206 } 3207 result.addAttribute( 3208 TransferWriteOp::getOperandSegmentSizeAttr(), 3209 
builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()), 3210 static_cast<int32_t>(hasMask.succeeded())})); 3211 return failure(shapedType.isa<RankedTensorType>() && 3212 parser.addTypeToList(shapedType, result.types)); 3213 } 3214 3215 void TransferWriteOp::print(OpAsmPrinter &p) { 3216 p << " " << vector() << ", " << source() << "[" << indices() << "]"; 3217 if (mask()) 3218 p << ", " << mask(); 3219 printTransferAttrs(p, *this); 3220 p << " : " << getVectorType() << ", " << getShapedType(); 3221 } 3222 3223 LogicalResult TransferWriteOp::verify() { 3224 // Consistency of elemental types in shape and vector. 3225 ShapedType shapedType = getShapedType(); 3226 VectorType vectorType = getVectorType(); 3227 VectorType maskType = getMaskType(); 3228 auto permutationMap = permutation_map(); 3229 3230 if (llvm::size(indices()) != shapedType.getRank()) 3231 return emitOpError("requires ") << shapedType.getRank() << " indices"; 3232 3233 // We do not allow broadcast dimensions on TransferWriteOps for the moment, 3234 // as the semantics is unclear. This can be revisited later if necessary. 3235 if (hasBroadcastDim()) 3236 return emitOpError("should not have broadcast dimensions"); 3237 3238 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()), 3239 shapedType, vectorType, maskType, permutationMap, 3240 in_bounds() ? *in_bounds() : ArrayAttr()))) 3241 return failure(); 3242 3243 return verifyPermutationMap(permutationMap, 3244 [&](Twine t) { return emitOpError(t); }); 3245 } 3246 3247 /// Fold: 3248 /// ``` 3249 /// %t1 = ... 3250 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} : 3251 /// tensor<static_sizesxf32>, vector<static_sizesxf32> 3252 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} : 3253 /// vector<static_sizesxf32>, tensor<static_sizesxf32> 3254 /// ``` 3255 /// 3256 /// into: 3257 /// 3258 /// ``` 3259 /// %t0 3260 /// ``` 3261 /// 3262 /// The producer of t1 may or may not be DCE'd depending on whether it is a 3263 /// block argument or has side effects. 3264 static LogicalResult foldReadInitWrite(TransferWriteOp write, 3265 ArrayRef<Attribute>, 3266 SmallVectorImpl<OpFoldResult> &results) { 3267 // TODO: support 0-d corner case. 3268 if (write.getTransferRank() == 0) 3269 return failure(); 3270 auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>(); 3271 // If not operating on tensors, bail. 3272 if (!rankedTensorType) 3273 return failure(); 3274 // If no read, bail. 3275 auto read = write.vector().getDefiningOp<vector::TransferReadOp>(); 3276 if (!read) 3277 return failure(); 3278 // TODO: support 0-d corner case. 3279 if (read.getTransferRank() == 0) 3280 return failure(); 3281 // For now, only accept minor identity. Future: composition is minor identity. 3282 if (!read.permutation_map().isMinorIdentity() || 3283 !write.permutation_map().isMinorIdentity()) 3284 return failure(); 3285 // Bail on mismatching ranks. 3286 if (read.getTransferRank() != write.getTransferRank()) 3287 return failure(); 3288 // Bail on potential out-of-bounds accesses. 3289 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim()) 3290 return failure(); 3291 // Tensor types must be the same. 3292 if (read.source().getType() != rankedTensorType) 3293 return failure(); 3294 // Vector types must be the same. 3295 if (read.getVectorType() != write.getVectorType()) 3296 return failure(); 3297 // Vector and Tensor shapes must match. 
3298   if (read.getVectorType().getShape() != rankedTensorType.getShape())
3299     return failure();
3300   // Bail if any index is nonzero.
3301   auto isNotConstantZero = [](Value v) {
3302     auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3303     return !cstOp || cstOp.value() != 0;
3304   };
3305   if (llvm::any_of(read.indices(), isNotConstantZero) ||
3306       llvm::any_of(write.indices(), isNotConstantZero))
3307     return failure();
3308   // Success.
3309   results.push_back(read.source());
3310   return success();
3311 }
3312
3313 static bool checkSameValueWAR(vector::TransferReadOp read,
3314                               vector::TransferWriteOp write) {
3315   return read.source() == write.source() && read.indices() == write.indices() &&
3316          read.permutation_map() == write.permutation_map() &&
3317          read.getVectorType() == write.getVectorType() && !read.mask() &&
3318          !write.mask();
3319 }
3320 /// Fold a transfer_write that writes back a value just read (write after read):
3321 /// ```
3322 ///    %t0 = ...
3323 ///    %v = vector.transfer_read %t0[%c0...] :
3324 ///      tensor<static_sizesxf32>, vector<static_sizesxf32>
3325 ///    %t1 = vector.transfer_write %v, %t0[%c0...] :
3326 ///      vector<static_sizesxf32>, tensor<static_sizesxf32>
3327 /// ```
3328 ///
3329 /// into:
3330 ///
3331 /// ```
3332 /// %t0
3333 /// ```
3334 static LogicalResult foldWAR(TransferWriteOp write,
3335                              SmallVectorImpl<OpFoldResult> &results) {
3336   if (!write.source().getType().isa<RankedTensorType>())
3337     return failure();
3338   auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3339   if (!read)
3340     return failure();
3341
3342   if (!checkSameValueWAR(read, write))
3343     return failure();
3344   results.push_back(read.source());
3345   return success();
3346 }
3347
3348 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3349                                     SmallVectorImpl<OpFoldResult> &results) {
3350   if (succeeded(foldReadInitWrite(*this, operands, results)))
3351     return success();
3352   if (succeeded(foldWAR(*this, results)))
3353     return success();
3354   if (succeeded(foldTransferInBoundsAttribute(*this)))
3355     return success();
3356   return foldMemRefCast(*this);
3357 }
3358
3359 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3360   return llvm::to_vector<4>(getVectorType().getShape());
3361 }
3362
3363 void TransferWriteOp::getEffects(
3364     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3365         &effects) {
3366   if (getShapedType().isa<MemRefType>())
3367     effects.emplace_back(MemoryEffects::Write::get(), source(),
3368                          SideEffects::DefaultResource::get());
3369 }
3370
3371 namespace {
3372 /// Remove a dead transfer_write from the SSA chain so that it can be
3373 /// eliminated by DCE.
3374 /// ```
3375 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3376 ///    : vector<1x4xf32>, tensor<4x4xf32>
3377 ///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3378 ///    : vector<1x4xf32>, tensor<4x4xf32>
3379 ///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3380 ///    : vector<1x4xf32>, tensor<4x4xf32>
3381 /// ```
3382 ///
3383 /// into:
3384 ///
3385 /// ```
3386 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3387 ///    : vector<1x4xf32>, tensor<4x4xf32>
3388 ///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3389 ///    : vector<1x4xf32>, tensor<4x4xf32>
3390 ///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3391 ///    : vector<1x4xf32>, tensor<4x4xf32>
3392 /// ```
3393 ///
3394 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3395 /// any other uses.
3396 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3397 public:
3398   using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
3399   LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3400                                 PatternRewriter &rewriter) const override {
3401     if (!writeOp.getShapedType().isa<RankedTensorType>())
3402       return failure();
3403     vector::TransferWriteOp writeToModify = writeOp;
3404
3405     auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
3406     while (defWrite) {
3407       if (checkSameValueWAW(writeOp, defWrite)) {
3408         writeToModify.sourceMutable().assign(defWrite.source());
3409         return success();
3410       }
3411       if (!isDisjointTransferIndices(
3412               cast<VectorTransferOpInterface>(defWrite.getOperation()),
3413               cast<VectorTransferOpInterface>(writeOp.getOperation())))
3414         break;
3415       // If the previous write op doesn't have any other use, we can safely
3416       // look at the previous store to see if it can be removed.
3417       if (!defWrite->hasOneUse())
3418         break;
3419       writeToModify = defWrite;
3420       defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
3421     }
3422     return failure();
3423   }
3424 };
3425
3426 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3427 /// could directly write to the insert_slice's destination. E.g.:
3428 ///
3429 /// ```
3430 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3431 ///     : vector<4x5xf32>, tensor<4x5xf32>
3432 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3433 ///     : tensor<4x5xf32> into tensor<?x?xf32>
3434 /// ```
3435 /// is rewritten to:
3436 /// ```
3437 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3438 ///     : vector<4x5xf32>, tensor<?x?xf32>
3439 /// ```
3440 struct FoldInsertSliceIntoTransferWrite
3441     : public OpRewritePattern<tensor::InsertSliceOp> {
3442 public:
3443   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3444
3445   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3446                                 PatternRewriter &rewriter) const override {
3447     if (!insertOp.hasUnitStride())
3448       return failure();
3449
3450     auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
3451     if (!xferOp)
3452       return failure();
3453     // TODO: support 0-d corner case.
3454     if (xferOp.getTransferRank() == 0)
3455       return failure();
3456
3457     if (xferOp.hasOutOfBoundsDim())
3458       return failure();
3459     if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3460       return failure();
3461     if (xferOp.mask())
3462       return failure();
3463     // Fold only if the TransferWriteOp completely overwrites the `source` with
3464     // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3465     // content is the data of the vector.
3466     if (!llvm::equal(xferOp.getVectorType().getShape(),
3467                      xferOp.getShapedType().getShape()))
3468       return failure();
3469     if (!xferOp.permutation_map().isIdentity())
3470       return failure();
3471
3472     // Bail on illegal rank-reduction: we need to check that the rank-reduced
3473     // dims are exactly the leading dims. I.e.
the following is illegal: 3474 // ``` 3475 // %0 = vector.transfer_write %v, %t[0,0], %cst : 3476 // vector<2x4xf32>, tensor<2x4xf32> 3477 // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] : 3478 // tensor<2x4xf32> into tensor<2x1x4xf32> 3479 // ``` 3480 // 3481 // Cannot fold into: 3482 // ``` 3483 // %0 = vector.transfer_write %v, %t[0,0,0], %cst : 3484 // vector<2x4xf32>, tensor<2x1x4xf32> 3485 // ``` 3486 // For this, check the trailing `vectorRank` dims of the insert_slice result 3487 // tensor match the trailing dims of the inferred result tensor. 3488 int64_t rankReduced = 3489 insertOp.getType().getRank() - insertOp.getSourceType().getRank(); 3490 int64_t vectorRank = xferOp.getVectorType().getRank(); 3491 RankedTensorType inferredSourceTensorType = 3492 tensor::ExtractSliceOp::inferResultType( 3493 insertOp.getType(), insertOp.getMixedOffsets(), 3494 insertOp.getMixedSizes(), insertOp.getMixedStrides()); 3495 auto actualSourceTensorShape = insertOp.getSourceType().getShape(); 3496 if (rankReduced > 0 && 3497 actualSourceTensorShape.take_back(vectorRank) != 3498 inferredSourceTensorType.getShape().take_back(vectorRank)) 3499 return failure(); 3500 3501 SmallVector<Value> indices = getValueOrCreateConstantIndexOp( 3502 rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); 3503 SmallVector<bool> inBounds(xferOp.getTransferRank(), true); 3504 rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(), 3505 insertOp.dest(), indices, 3506 ArrayRef<bool>{inBounds}); 3507 return success(); 3508 } 3509 }; 3510 } // namespace 3511 3512 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, 3513 MLIRContext *context) { 3514 results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context); 3515 } 3516 3517 //===----------------------------------------------------------------------===// 3518 // LoadOp 3519 //===----------------------------------------------------------------------===// 3520 3521 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, 3522 MemRefType memRefTy) { 3523 if (!isLastMemrefDimUnitStride(memRefTy)) 3524 return op->emitOpError("most minor memref dim must have unit stride"); 3525 return success(); 3526 } 3527 3528 LogicalResult vector::LoadOp::verify() { 3529 VectorType resVecTy = getVectorType(); 3530 MemRefType memRefTy = getMemRefType(); 3531 3532 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) 3533 return failure(); 3534 3535 // Checks for vector memrefs. 
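// For illustration (hypothetical values and types), both forms are accepted:
//   vector.load %base[%i] : memref<100xf32>, vector<8xf32>
//   vector.load %base[%i] : memref<100xvector<8xf32>>, vector<8xf32>
// In the second form the memref element vector type must match the result type.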
3536 Type memElemTy = memRefTy.getElementType(); 3537 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3538 if (memVecTy != resVecTy) 3539 return emitOpError("base memref and result vector types should match"); 3540 memElemTy = memVecTy.getElementType(); 3541 } 3542 3543 if (resVecTy.getElementType() != memElemTy) 3544 return emitOpError("base and result element types should match"); 3545 if (llvm::size(indices()) != memRefTy.getRank()) 3546 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 3547 return success(); 3548 } 3549 3550 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) { 3551 if (succeeded(foldMemRefCast(*this))) 3552 return getResult(); 3553 return OpFoldResult(); 3554 } 3555 3556 //===----------------------------------------------------------------------===// 3557 // StoreOp 3558 //===----------------------------------------------------------------------===// 3559 3560 LogicalResult vector::StoreOp::verify() { 3561 VectorType valueVecTy = getVectorType(); 3562 MemRefType memRefTy = getMemRefType(); 3563 3564 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) 3565 return failure(); 3566 3567 // Checks for vector memrefs. 3568 Type memElemTy = memRefTy.getElementType(); 3569 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3570 if (memVecTy != valueVecTy) 3571 return emitOpError( 3572 "base memref and valueToStore vector types should match"); 3573 memElemTy = memVecTy.getElementType(); 3574 } 3575 3576 if (valueVecTy.getElementType() != memElemTy) 3577 return emitOpError("base and valueToStore element type should match"); 3578 if (llvm::size(indices()) != memRefTy.getRank()) 3579 return emitOpError("requires ") << memRefTy.getRank() << " indices"; 3580 return success(); 3581 } 3582 3583 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands, 3584 SmallVectorImpl<OpFoldResult> &results) { 3585 return foldMemRefCast(*this); 3586 } 3587 3588 //===----------------------------------------------------------------------===// 3589 // MaskedLoadOp 3590 //===----------------------------------------------------------------------===// 3591 3592 LogicalResult MaskedLoadOp::verify() { 3593 VectorType maskVType = getMaskVectorType(); 3594 VectorType passVType = getPassThruVectorType(); 3595 VectorType resVType = getVectorType(); 3596 MemRefType memType = getMemRefType(); 3597 3598 if (resVType.getElementType() != memType.getElementType()) 3599 return emitOpError("base and result element type should match"); 3600 if (llvm::size(indices()) != memType.getRank()) 3601 return emitOpError("requires ") << memType.getRank() << " indices"; 3602 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3603 return emitOpError("expected result dim to match mask dim"); 3604 if (resVType != passVType) 3605 return emitOpError("expected pass_thru of same type as result type"); 3606 return success(); 3607 } 3608 3609 namespace { 3610 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { 3611 public: 3612 using OpRewritePattern<MaskedLoadOp>::OpRewritePattern; 3613 LogicalResult matchAndRewrite(MaskedLoadOp load, 3614 PatternRewriter &rewriter) const override { 3615 switch (get1DMaskFormat(load.mask())) { 3616 case MaskFormat::AllTrue: 3617 rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(), 3618 load.base(), load.indices()); 3619 return success(); 3620 case MaskFormat::AllFalse: 3621 rewriter.replaceOp(load, load.pass_thru()); 3622 return success(); 3623 case MaskFormat::Unknown: 3624 return failure(); 3625 } 3626 llvm_unreachable("Unexpected 
1DMaskFormat on MaskedLoad"); 3627 } 3628 }; 3629 } // namespace 3630 3631 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3632 MLIRContext *context) { 3633 results.add<MaskedLoadFolder>(context); 3634 } 3635 3636 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) { 3637 if (succeeded(foldMemRefCast(*this))) 3638 return getResult(); 3639 return OpFoldResult(); 3640 } 3641 3642 //===----------------------------------------------------------------------===// 3643 // MaskedStoreOp 3644 //===----------------------------------------------------------------------===// 3645 3646 LogicalResult MaskedStoreOp::verify() { 3647 VectorType maskVType = getMaskVectorType(); 3648 VectorType valueVType = getVectorType(); 3649 MemRefType memType = getMemRefType(); 3650 3651 if (valueVType.getElementType() != memType.getElementType()) 3652 return emitOpError("base and valueToStore element type should match"); 3653 if (llvm::size(indices()) != memType.getRank()) 3654 return emitOpError("requires ") << memType.getRank() << " indices"; 3655 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3656 return emitOpError("expected valueToStore dim to match mask dim"); 3657 return success(); 3658 } 3659 3660 namespace { 3661 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { 3662 public: 3663 using OpRewritePattern<MaskedStoreOp>::OpRewritePattern; 3664 LogicalResult matchAndRewrite(MaskedStoreOp store, 3665 PatternRewriter &rewriter) const override { 3666 switch (get1DMaskFormat(store.mask())) { 3667 case MaskFormat::AllTrue: 3668 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3669 store, store.valueToStore(), store.base(), store.indices()); 3670 return success(); 3671 case MaskFormat::AllFalse: 3672 rewriter.eraseOp(store); 3673 return success(); 3674 case MaskFormat::Unknown: 3675 return failure(); 3676 } 3677 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore"); 3678 } 3679 }; 3680 } // namespace 3681 3682 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3683 MLIRContext *context) { 3684 results.add<MaskedStoreFolder>(context); 3685 } 3686 3687 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands, 3688 SmallVectorImpl<OpFoldResult> &results) { 3689 return foldMemRefCast(*this); 3690 } 3691 3692 //===----------------------------------------------------------------------===// 3693 // GatherOp 3694 //===----------------------------------------------------------------------===// 3695 3696 LogicalResult GatherOp::verify() { 3697 VectorType indVType = getIndexVectorType(); 3698 VectorType maskVType = getMaskVectorType(); 3699 VectorType resVType = getVectorType(); 3700 MemRefType memType = getMemRefType(); 3701 3702 if (resVType.getElementType() != memType.getElementType()) 3703 return emitOpError("base and result element type should match"); 3704 if (llvm::size(indices()) != memType.getRank()) 3705 return emitOpError("requires ") << memType.getRank() << " indices"; 3706 if (resVType.getDimSize(0) != indVType.getDimSize(0)) 3707 return emitOpError("expected result dim to match indices dim"); 3708 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3709 return emitOpError("expected result dim to match mask dim"); 3710 if (resVType != getPassThruVectorType()) 3711 return emitOpError("expected pass_thru of same type as result type"); 3712 return success(); 3713 } 3714 3715 namespace { 3716 class GatherFolder final : public OpRewritePattern<GatherOp> { 3717 public: 3718 using OpRewritePattern<GatherOp>::OpRewritePattern; 
3719 LogicalResult matchAndRewrite(GatherOp gather, 3720 PatternRewriter &rewriter) const override { 3721 switch (get1DMaskFormat(gather.mask())) { 3722 case MaskFormat::AllTrue: 3723 return failure(); // no unmasked equivalent 3724 case MaskFormat::AllFalse: 3725 rewriter.replaceOp(gather, gather.pass_thru()); 3726 return success(); 3727 case MaskFormat::Unknown: 3728 return failure(); 3729 } 3730 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder"); 3731 } 3732 }; 3733 } // namespace 3734 3735 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, 3736 MLIRContext *context) { 3737 results.add<GatherFolder>(context); 3738 } 3739 3740 //===----------------------------------------------------------------------===// 3741 // ScatterOp 3742 //===----------------------------------------------------------------------===// 3743 3744 LogicalResult ScatterOp::verify() { 3745 VectorType indVType = getIndexVectorType(); 3746 VectorType maskVType = getMaskVectorType(); 3747 VectorType valueVType = getVectorType(); 3748 MemRefType memType = getMemRefType(); 3749 3750 if (valueVType.getElementType() != memType.getElementType()) 3751 return emitOpError("base and valueToStore element type should match"); 3752 if (llvm::size(indices()) != memType.getRank()) 3753 return emitOpError("requires ") << memType.getRank() << " indices"; 3754 if (valueVType.getDimSize(0) != indVType.getDimSize(0)) 3755 return emitOpError("expected valueToStore dim to match indices dim"); 3756 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3757 return emitOpError("expected valueToStore dim to match mask dim"); 3758 return success(); 3759 } 3760 3761 namespace { 3762 class ScatterFolder final : public OpRewritePattern<ScatterOp> { 3763 public: 3764 using OpRewritePattern<ScatterOp>::OpRewritePattern; 3765 LogicalResult matchAndRewrite(ScatterOp scatter, 3766 PatternRewriter &rewriter) const override { 3767 switch (get1DMaskFormat(scatter.mask())) { 3768 case MaskFormat::AllTrue: 3769 return failure(); // no unmasked equivalent 3770 case MaskFormat::AllFalse: 3771 rewriter.eraseOp(scatter); 3772 return success(); 3773 case MaskFormat::Unknown: 3774 return failure(); 3775 } 3776 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder"); 3777 } 3778 }; 3779 } // namespace 3780 3781 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results, 3782 MLIRContext *context) { 3783 results.add<ScatterFolder>(context); 3784 } 3785 3786 //===----------------------------------------------------------------------===// 3787 // ExpandLoadOp 3788 //===----------------------------------------------------------------------===// 3789 3790 LogicalResult ExpandLoadOp::verify() { 3791 VectorType maskVType = getMaskVectorType(); 3792 VectorType passVType = getPassThruVectorType(); 3793 VectorType resVType = getVectorType(); 3794 MemRefType memType = getMemRefType(); 3795 3796 if (resVType.getElementType() != memType.getElementType()) 3797 return emitOpError("base and result element type should match"); 3798 if (llvm::size(indices()) != memType.getRank()) 3799 return emitOpError("requires ") << memType.getRank() << " indices"; 3800 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3801 return emitOpError("expected result dim to match mask dim"); 3802 if (resVType != passVType) 3803 return emitOpError("expected pass_thru of same type as result type"); 3804 return success(); 3805 } 3806 3807 namespace { 3808 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { 3809 public: 3810 using 
OpRewritePattern<ExpandLoadOp>::OpRewritePattern; 3811 LogicalResult matchAndRewrite(ExpandLoadOp expand, 3812 PatternRewriter &rewriter) const override { 3813 switch (get1DMaskFormat(expand.mask())) { 3814 case MaskFormat::AllTrue: 3815 rewriter.replaceOpWithNewOp<vector::LoadOp>( 3816 expand, expand.getType(), expand.base(), expand.indices()); 3817 return success(); 3818 case MaskFormat::AllFalse: 3819 rewriter.replaceOp(expand, expand.pass_thru()); 3820 return success(); 3821 case MaskFormat::Unknown: 3822 return failure(); 3823 } 3824 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder"); 3825 } 3826 }; 3827 } // namespace 3828 3829 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3830 MLIRContext *context) { 3831 results.add<ExpandLoadFolder>(context); 3832 } 3833 3834 //===----------------------------------------------------------------------===// 3835 // CompressStoreOp 3836 //===----------------------------------------------------------------------===// 3837 3838 LogicalResult CompressStoreOp::verify() { 3839 VectorType maskVType = getMaskVectorType(); 3840 VectorType valueVType = getVectorType(); 3841 MemRefType memType = getMemRefType(); 3842 3843 if (valueVType.getElementType() != memType.getElementType()) 3844 return emitOpError("base and valueToStore element type should match"); 3845 if (llvm::size(indices()) != memType.getRank()) 3846 return emitOpError("requires ") << memType.getRank() << " indices"; 3847 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3848 return emitOpError("expected valueToStore dim to match mask dim"); 3849 return success(); 3850 } 3851 3852 namespace { 3853 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { 3854 public: 3855 using OpRewritePattern<CompressStoreOp>::OpRewritePattern; 3856 LogicalResult matchAndRewrite(CompressStoreOp compress, 3857 PatternRewriter &rewriter) const override { 3858 switch (get1DMaskFormat(compress.mask())) { 3859 case MaskFormat::AllTrue: 3860 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3861 compress, compress.valueToStore(), compress.base(), 3862 compress.indices()); 3863 return success(); 3864 case MaskFormat::AllFalse: 3865 rewriter.eraseOp(compress); 3866 return success(); 3867 case MaskFormat::Unknown: 3868 return failure(); 3869 } 3870 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder"); 3871 } 3872 }; 3873 } // namespace 3874 3875 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3876 MLIRContext *context) { 3877 results.add<CompressStoreFolder>(context); 3878 } 3879 3880 //===----------------------------------------------------------------------===// 3881 // ShapeCastOp 3882 //===----------------------------------------------------------------------===// 3883 3884 /// Returns true if each element of 'a' is equal to the product of a contiguous 3885 /// sequence of the elements of 'b'. Returns false otherwise. 3886 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 3887 unsigned rankA = a.size(); 3888 unsigned rankB = b.size(); 3889 assert(rankA < rankB); 3890 3891 unsigned i = 0; 3892 unsigned j = 0; 3893 while (i < rankA && j < rankB) { 3894 int64_t dimA = a[i]; 3895 int64_t dimB = 1; 3896 while (dimB < dimA && j < rankB) 3897 dimB *= b[j++]; 3898 if (dimA != dimB) 3899 break; 3900 ++i; 3901 3902 // Handle the case when trailing dimensions are of size 1. 3903 // Include them into the contiguous sequence. 
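// For example (a minimal sketch): isValidShapeCast({6}, {2, 3, 1, 1}) holds
// because 6 == 2 * 3 and the remaining trailing dimensions are all 1.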
3904     auto isOne = [](int64_t v) { return v == 1; };
3905     if (i < rankA && llvm::all_of(a.slice(i), isOne))
3906       i = rankA;
3907     if (j < rankB && llvm::all_of(b.slice(j), isOne))
3908       j = rankB;
3909   }
3910
3911   return i == rankA && j == rankB;
3912 }
3913
3914 static LogicalResult verifyVectorShapeCast(Operation *op,
3915                                            VectorType sourceVectorType,
3916                                            VectorType resultVectorType) {
3917   // Check that element type is the same.
3918   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
3919     return op->emitOpError("source/result vectors must have same element type");
3920   auto sourceShape = sourceVectorType.getShape();
3921   auto resultShape = resultVectorType.getShape();
3922
3923   // Check that product of source dim sizes matches product of result dim sizes.
3924   int64_t sourceDimProduct = std::accumulate(
3925       sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
3926   int64_t resultDimProduct = std::accumulate(
3927       resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
3928   if (sourceDimProduct != resultDimProduct)
3929     return op->emitOpError("source/result number of elements must match");
3930
3931   // Check the expanding/contracting rank cases.
3932   unsigned sourceRank = sourceVectorType.getRank();
3933   unsigned resultRank = resultVectorType.getRank();
3934   if (sourceRank < resultRank) {
3935     if (!isValidShapeCast(sourceShape, resultShape))
3936       return op->emitOpError("invalid shape cast");
3937   } else if (sourceRank > resultRank) {
3938     if (!isValidShapeCast(resultShape, sourceShape))
3939       return op->emitOpError("invalid shape cast");
3940   }
3941   return success();
3942 }
3943
3944 LogicalResult ShapeCastOp::verify() {
3945   auto sourceVectorType = source().getType().dyn_cast_or_null<VectorType>();
3946   auto resultVectorType = result().getType().dyn_cast_or_null<VectorType>();
3947
3948   // Check if source/result are of vector type.
3949   if (sourceVectorType && resultVectorType)
3950     return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
3951
3952   return success();
3953 }
3954
3955 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
3956   // Nop shape cast.
3957   if (source().getType() == result().getType())
3958     return source();
3959
3960   // Canceling shape casts.
3961   if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
3962     if (result().getType() == otherOp.source().getType())
3963       return otherOp.source();
3964
3965     // Only allow valid transitive folding.
3966     VectorType srcType = otherOp.source().getType().cast<VectorType>();
3967     VectorType resultType = getResult().getType().cast<VectorType>();
3968     if (srcType.getRank() < resultType.getRank()) {
3969       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
3970         return {};
3971     } else if (srcType.getRank() > resultType.getRank()) {
3972       if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
3973         return {};
3974     } else {
3975       return {};
3976     }
3977
3978     setOperand(otherOp.source());
3979     return getResult();
3980   }
3981   return {};
3982 }
3983
3984 namespace {
3985 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
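// For example (an illustrative sketch; the constant value is hypothetical):
//   %cst = arith.constant dense<1.0> : vector<8xf32>
//   %0 = vector.shape_cast %cst : vector<8xf32> to vector<2x4xf32>
// is rewritten to:
//   %0 = arith.constant dense<1.0> : vector<2x4xf32>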
3986 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { 3987 public: 3988 using OpRewritePattern<ShapeCastOp>::OpRewritePattern; 3989 3990 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, 3991 PatternRewriter &rewriter) const override { 3992 auto constantOp = shapeCastOp.source().getDefiningOp<arith::ConstantOp>(); 3993 if (!constantOp) 3994 return failure(); 3995 // Only handle splat for now. 3996 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 3997 if (!dense) 3998 return failure(); 3999 auto newAttr = 4000 DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(), 4001 dense.getSplatValue<Attribute>()); 4002 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr); 4003 return success(); 4004 } 4005 }; 4006 4007 } // namespace 4008 4009 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 4010 MLIRContext *context) { 4011 // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp. 4012 results.add<ShapeCastConstantFolder>(context); 4013 } 4014 4015 //===----------------------------------------------------------------------===// 4016 // VectorBitCastOp 4017 //===----------------------------------------------------------------------===// 4018 4019 LogicalResult BitCastOp::verify() { 4020 auto sourceVectorType = getSourceVectorType(); 4021 auto resultVectorType = getResultVectorType(); 4022 4023 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { 4024 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) 4025 return emitOpError("dimension size mismatch at: ") << i; 4026 } 4027 4028 DataLayout dataLayout = DataLayout::closest(*this); 4029 auto sourceElementBits = 4030 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); 4031 auto resultElementBits = 4032 dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); 4033 4034 if (sourceVectorType.getRank() == 0) { 4035 if (sourceElementBits != resultElementBits) 4036 return emitOpError("source/result bitwidth of the 0-D vector element " 4037 "types must be equal"); 4038 } else if (sourceElementBits * sourceVectorType.getShape().back() != 4039 resultElementBits * resultVectorType.getShape().back()) { 4040 return emitOpError( 4041 "source/result bitwidth of the minor 1-D vectors must be equal"); 4042 } 4043 4044 return success(); 4045 } 4046 4047 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) { 4048 // Nop cast. 4049 if (source().getType() == result().getType()) 4050 return source(); 4051 4052 // Canceling bitcasts. 4053 if (auto otherOp = source().getDefiningOp<BitCastOp>()) 4054 if (result().getType() == otherOp.source().getType()) 4055 return otherOp.source(); 4056 4057 Attribute sourceConstant = operands.front(); 4058 if (!sourceConstant) 4059 return {}; 4060 4061 Type srcElemType = getSourceVectorType().getElementType(); 4062 Type dstElemType = getResultVectorType().getElementType(); 4063 4064 if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) { 4065 if (floatPack.isSplat()) { 4066 auto splat = floatPack.getSplatValue<FloatAttr>(); 4067 4068 // Casting fp16 into fp32. 4069 if (srcElemType.isF16() && dstElemType.isF32()) { 4070 uint32_t bits = static_cast<uint32_t>( 4071 splat.getValue().bitcastToAPInt().getZExtValue()); 4072 // Duplicate the 16-bit pattern. 
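// (Each f32 lane of the result overlays two adjacent f16 lanes of the splat
// source, so the 32-bit pattern is the 16-bit splat pattern repeated twice.)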
4073 bits = (bits << 16) | (bits & 0xffff); 4074 APInt intBits(32, bits); 4075 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits); 4076 return DenseElementsAttr::get(getResultVectorType(), floatBits); 4077 } 4078 } 4079 } 4080 4081 return {}; 4082 } 4083 4084 //===----------------------------------------------------------------------===// 4085 // TypeCastOp 4086 //===----------------------------------------------------------------------===// 4087 4088 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { 4089 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>(); 4090 SmallVector<int64_t, 8> res(memRefType.getShape().begin(), 4091 memRefType.getShape().end()); 4092 if (vectorType) 4093 res.append(vectorType.getShape().begin(), vectorType.getShape().end()); 4094 return res; 4095 } 4096 4097 /// Build the canonical memRefType with a single vector. 4098 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>. 4099 void TypeCastOp::build(OpBuilder &builder, OperationState &result, 4100 Value source) { 4101 result.addOperands(source); 4102 MemRefType memRefType = source.getType().cast<MemRefType>(); 4103 VectorType vectorType = 4104 VectorType::get(extractShape(memRefType), 4105 getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); 4106 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), 4107 memRefType.getMemorySpace())); 4108 } 4109 4110 LogicalResult TypeCastOp::verify() { 4111 MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType()); 4112 if (!canonicalType.getLayout().isIdentity()) 4113 return emitOpError("expects operand to be a memref with identity layout"); 4114 if (!getResultMemRefType().getLayout().isIdentity()) 4115 return emitOpError("expects result to be a memref with identity layout"); 4116 if (getResultMemRefType().getMemorySpace() != 4117 getMemRefType().getMemorySpace()) 4118 return emitOpError("expects result in same memory space"); 4119 4120 auto sourceType = getMemRefType(); 4121 auto resultType = getResultMemRefType(); 4122 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != 4123 getElementTypeOrSelf(getElementTypeOrSelf(resultType))) 4124 return emitOpError( 4125 "expects result and operand with same underlying scalar type: ") 4126 << resultType; 4127 if (extractShape(sourceType) != extractShape(resultType)) 4128 return emitOpError( 4129 "expects concatenated result and operand shapes to be equal: ") 4130 << resultType; 4131 return success(); 4132 } 4133 4134 //===----------------------------------------------------------------------===// 4135 // TransposeOp 4136 //===----------------------------------------------------------------------===// 4137 4138 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, 4139 Value vector, ArrayRef<int64_t> transp) { 4140 VectorType vt = vector.getType().cast<VectorType>(); 4141 SmallVector<int64_t, 4> transposedShape(vt.getRank()); 4142 for (unsigned i = 0; i < transp.size(); ++i) 4143 transposedShape[i] = vt.getShape()[transp[i]]; 4144 4145 result.addOperands(vector); 4146 result.addTypes(VectorType::get(transposedShape, vt.getElementType())); 4147 result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp)); 4148 } 4149 4150 // Eliminates transpose operations, which produce values identical to their 4151 // input values. This happens when the dimensions of the input vector remain in 4152 // their original order after the transpose operation. 
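// For example (illustrative): vector.transpose %v, [0, 1] : vector<2x3xf32> to
// vector<2x3xf32> folds to %v, whereas a permutation of [1, 0] does not fold.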
4153 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) { 4154 SmallVector<int64_t, 4> transp; 4155 getTransp(transp); 4156 4157 // Check if the permutation of the dimensions contains sequential values: 4158 // {0, 1, 2, ...}. 4159 for (int64_t i = 0, e = transp.size(); i < e; i++) { 4160 if (transp[i] != i) 4161 return {}; 4162 } 4163 4164 return vector(); 4165 } 4166 4167 LogicalResult vector::TransposeOp::verify() { 4168 VectorType vectorType = getVectorType(); 4169 VectorType resultType = getResultType(); 4170 int64_t rank = resultType.getRank(); 4171 if (vectorType.getRank() != rank) 4172 return emitOpError("vector result rank mismatch: ") << rank; 4173 // Verify transposition array. 4174 auto transpAttr = transp().getValue(); 4175 int64_t size = transpAttr.size(); 4176 if (rank != size) 4177 return emitOpError("transposition length mismatch: ") << size; 4178 SmallVector<bool, 8> seen(rank, false); 4179 for (const auto &ta : llvm::enumerate(transpAttr)) { 4180 int64_t i = ta.value().cast<IntegerAttr>().getInt(); 4181 if (i < 0 || i >= rank) 4182 return emitOpError("transposition index out of range: ") << i; 4183 if (seen[i]) 4184 return emitOpError("duplicate position index: ") << i; 4185 seen[i] = true; 4186 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) 4187 return emitOpError("dimension size mismatch at: ") << i; 4188 } 4189 return success(); 4190 } 4191 4192 namespace { 4193 4194 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. 4195 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { 4196 public: 4197 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 4198 4199 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 4200 PatternRewriter &rewriter) const override { 4201 // Wrapper around vector::TransposeOp::getTransp() for cleaner code. 4202 auto getPermutation = [](vector::TransposeOp transpose) { 4203 SmallVector<int64_t, 4> permutation; 4204 transpose.getTransp(permutation); 4205 return permutation; 4206 }; 4207 4208 // Composes two permutations: result[i] = permutation1[permutation2[i]]. 4209 auto composePermutations = [](ArrayRef<int64_t> permutation1, 4210 ArrayRef<int64_t> permutation2) { 4211 SmallVector<int64_t, 4> result; 4212 for (auto index : permutation2) 4213 result.push_back(permutation1[index]); 4214 return result; 4215 }; 4216 4217 // Return if the input of 'transposeOp' is not defined by another transpose. 4218 vector::TransposeOp parentTransposeOp = 4219 transposeOp.vector().getDefiningOp<vector::TransposeOp>(); 4220 if (!parentTransposeOp) 4221 return failure(); 4222 4223 SmallVector<int64_t, 4> permutation = composePermutations( 4224 getPermutation(parentTransposeOp), getPermutation(transposeOp)); 4225 // Replace 'transposeOp' with a new transpose operation. 
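// E.g. (a hypothetical sketch): composing a parent permutation of [1, 0, 2]
// with a current permutation of [2, 0, 1] above yields [2, 1, 0].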
4226     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
4227         transposeOp, transposeOp.getResult().getType(),
4228         parentTransposeOp.vector(),
4229         vector::getVectorSubscriptAttr(rewriter, permutation));
4230     return success();
4231   }
4232 };
4233
4234 } // namespace
4235
4236 void vector::TransposeOp::getCanonicalizationPatterns(
4237     RewritePatternSet &results, MLIRContext *context) {
4238   results.add<TransposeFolder>(context);
4239 }
4240
4241 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
4242   populateFromInt64AttrArray(transp(), results);
4243 }
4244
4245 //===----------------------------------------------------------------------===//
4246 // ConstantMaskOp
4247 //===----------------------------------------------------------------------===//
4248
4249 LogicalResult ConstantMaskOp::verify() {
4250   auto resultType = getResult().getType().cast<VectorType>();
4251   // Check the corner case of 0-D vectors first.
4252   if (resultType.getRank() == 0) {
4253     if (mask_dim_sizes().size() != 1)
4254       return emitError("array attr must have length 1 for 0-D vectors");
4255     auto dim = mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
4256     if (dim != 0 && dim != 1)
4257       return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
4258     return success();
4259   }
4260
4261   // Verify that array attr size matches the rank of the vector result.
4262   if (static_cast<int64_t>(mask_dim_sizes().size()) != resultType.getRank())
4263     return emitOpError(
4264         "must specify array attr of size equal vector result rank");
4265   // Verify that each array attr element is in bounds of corresponding vector
4266   // result dimension size.
4267   auto resultShape = resultType.getShape();
4268   SmallVector<int64_t, 4> maskDimSizes;
4269   for (const auto &it : llvm::enumerate(mask_dim_sizes())) {
4270     int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
4271     if (attrValue < 0 || attrValue > resultShape[it.index()])
4272       return emitOpError(
4273           "array attr of size out of bounds of vector result dimension size");
4274     maskDimSizes.push_back(attrValue);
4275   }
4276   // Verify that if one mask dim size is zero, then all of them are zero
4277   // (because the mask region is a conjunction of each mask dimension interval).
4278   bool anyZeros = llvm::is_contained(maskDimSizes, 0);
4279   bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
4280   if (anyZeros && !allZeros)
4281     return emitOpError("expected all mask dim sizes to be zeros, "
4282                        "as a result of conjunction with zero mask dim");
4283   return success();
4284 }
4285
4286 //===----------------------------------------------------------------------===//
4287 // CreateMaskOp
4288 //===----------------------------------------------------------------------===//
4289
4290 LogicalResult CreateMaskOp::verify() {
4291   auto vectorType = getResult().getType().cast<VectorType>();
4292   // Verify that an operand was specified for each result vector dimension.
4293   if (vectorType.getRank() == 0) {
4294     if (getNumOperands() != 1)
4295       return emitOpError(
4296           "must specify exactly one operand for 0-D create_mask");
4297   } else if (getNumOperands() !=
4298              getResult().getType().cast<VectorType>().getRank()) {
4299     return emitOpError(
4300         "must specify an operand for each result vector dimension");
4301   }
4302   return success();
4303 }
4304
4305 namespace {
4306
4307 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
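// For example (an illustrative sketch; %c3 and %c2 are hypothetical
// arith.constant index values):
//   %m = vector.create_mask %c3, %c2 : vector<4x4xi1>
// is rewritten to:
//   %m = vector.constant_mask [3, 2] : vector<4x4xi1>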
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto it : llvm::zip(createMaskOp.operands(),
                             createMaskOp.getType().getShape())) {
      auto *defOp = std::get<0>(it).getDefiningOp();
      int64_t maxDimSize = std::get<1>(it);
      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
      dimSize = std::min(dimSize, maxDimSize);
      // If one of the dim sizes is zero, set all dims to zero.
      if (dimSize <= 0) {
        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
        break;
      }
      maskDimSizes.push_back(dimSize);
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};

} // namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = reduction_dim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
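  // The initial value must have the source shape with the reduction dimension
  // dropped. For example, a scan of vector<4x8x16xf32> along reduction_dim = 1
  // expects an initial value of type vector<4x16xf32>.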
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/initial value shapes");
  }

  return success();
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = operands.front();
  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as being a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"