//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat get1DMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.mask_dim_sizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}

// Helper for verifying combining kinds in contractions and reductions.
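// For instance (summarizing the switch below): ADD and MUL are accepted for
// any int/index/float element type, the bitwise and signed/unsigned min/max
// kinds require an int or index element type, and MINF/MAXF require a float
// element type.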
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return elementType.isa<FloatType>();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      shapedType.getElementType().dyn_cast<VectorType>();
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
         defWrite.indices() == read.indices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.permutation_map() == read.permutation_map();
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.indices() == write.indices() &&
         priorWrite.mask() == write.mask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.permutation_map() == write.permutation_map();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
    auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
    auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
    // If any of the indices are dynamic we cannot prove anything.
    if (!indexA || !indexB)
      continue;

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are different
      // we know we are accessing disjoint slices.
      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
          indexB.getValue().cast<IntegerAttr>().getInt())
        return true;
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t distance =
          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
                   indexB.getValue().cast<IntegerAttr>().getInt());
      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
        return true;
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB) {
  if (transferA.source() != transferB.source())
    return false;
  return isDisjointTransferIndices(transferA, transferB);
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MINF,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,
    CombiningKind::MAXF,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};

void CombiningKindAttr::print(AsmPrinter &printer) const {
  printer << "<";
  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
    return bitEnumContains(this->getKind(), kind);
  });
  llvm::interleaveComma(kinds, printer,
                        [&](auto kind) { printer << stringifyEnum(kind); });
  printer << ">";
}

Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  StringRef elemName;
  if (failed(parser.parseKeyword(&elemName)))
    return {};

  auto kind = symbolizeCombiningKind(elemName);
  if (!kind) {
    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
        << elemName;
    return {};
  }

  if (failed(parser.parseGreater()))
    return {};

  return CombiningKindAttr::get(kind.getValue(), parser.getContext());
}

Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  if (attrKind == "kind")
    return CombiningKindAttr::parse(parser, {});

  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
  return {};
}

void
VectorDialect::printAttribute(Attribute attr,
                              DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
    os << "kind";
    ck.print(os);
    return;
  }
  llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  result.addOperands(source);
  auto sourceVectorType = source.getType().cast<VectorType>();
  auto targetType = MultiDimReductionOp::inferDestType(
      sourceVectorType.getShape(), reductionMask,
      sourceVectorType.getElementType());
  result.addTypes(targetType);

  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  result.addAttribute(getReductionDimsAttrName(),
                      builder.getI64ArrayAttr(reductionDims));
  result.addAttribute(getKindAttrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}

static LogicalResult verify(MultiDimReductionOp op) {
  auto reductionMask = op.getReductionMask();
  auto targetType = MultiDimReductionOp::inferDestType(
      op.getSourceVectorType().getShape(), reductionMask,
      op.getSourceVectorType().getElementType());
  // TODO: update to support 0-d vectors when available.
  if (targetType != op.getDestType())
    return op.emitError("invalid output vector type: ")
           << op.getDestType() << " (expected: " << targetType << ")";
  return success();
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return source();
  return {};
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReductionOp op) {
  // Verify for 1-D vector.
  int64_t rank = op.getVectorType().getRank();
  if (rank != 1)
    return op.emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
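  // e.g. "add" and "mul" are accepted for int and float element types, the
  // signed/unsigned and bitwise kinds only for int/index element types, and
  // "minf"/"maxf" only for float element types (see isSupportedCombiningKind
  // above).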
  StringRef strKind = op.kind();
  auto maybeKind = symbolizeCombiningKind(strKind);
  if (!maybeKind)
    return op.emitOpError("unknown reduction kind: ") << strKind;

  Type eltType = op.dest().getType();
  if (!isSupportedCombiningKind(*maybeKind, eltType))
    return op.emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << op.kind() << "'";

  // Verify optional accumulator.
  if (!op.acc().empty()) {
    if (strKind != "add" && strKind != "mul")
      return op.emitOpError("no accumulator for reduction kind: ") << strKind;
    if (!eltType.isa<FloatType>())
      return op.emitOpError("no accumulator for type: ") << eltType;
  }

  return success();
}

static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (!operandsInfo.empty() &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.empty() || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

static void print(OpAsmPrinter &p, ReductionOp op) {
  p << " \"" << op.kind() << "\", " << op.vector();
  if (!op.acc().empty())
    p << ", " << op.acc();
  p << " : " << op.vector().getType() << " into " << op.dest().getType();
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  Type scalarType = vector.getType().cast<ShapedType>().getElementType();
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("add"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("mul"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::minf:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minf"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minsi"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minui"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxf:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxf"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxsi"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxui"),
                                               vector, ValueRange{});
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrName(),
                      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                             builder.getContext()));
}

static ParseResult parseContractionOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
    result.addAttribute(ContractionOp::getKindAttrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

static void print(OpAsmPrinter &p, ContractionOp op) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = op.getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : op->getAttrs())
    if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
  p << " " << dictAttr << " " << op.lhs() << ", ";
  p << op.rhs() << ", " << op.acc();
  if (op.masks().size() == 2)
    p << ", " << op.masks();

  p.printOptionalAttrDict(op->getAttrs(), attrNames);
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
    << op.getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
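  // e.g. a plain matmul-style contraction would typically carry three maps:
  //   (m, n, k) -> (m, k)  for the lhs,
  //   (m, n, k) -> (k, n)  for the rhs,
  //   (m, n, k) -> (m, n)  for the accumulator/result.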
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (const auto &it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(op.kind(), elementType))
    return op.emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrName()};
  return llvm::makeArrayRef(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
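    // e.g. for a matmul-like contraction with maps
    // [(m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)] and a
    // vector<4x16xf32> result, the parallel bounds m = 4 and n = 16 come from
    // the result shape here, while the reduction bound k was taken from the
    // lhs shape above.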
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  return llvm::to_vector<4>(
      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
        return mapAttr.cast<AffineMapAttr>().getValue();
      }));
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
///   %c0 = arith.constant 0: ...
///   %c = vector.contract %a, %b, %c0: ...
///   %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
///   %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.acc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.acc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.acc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {
  result.addOperands({source});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source, Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

static LogicalResult verify(vector::ExtractElementOp op) {
  VectorType vectorType = op.getVectorType();
  if (vectorType.getRank() == 0) {
    if (op.position())
      return op.emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return op.emitOpError("unexpected >1 vector rank");
  if (!op.position())
    return op.emitOpError("expected position for 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
    return vectorType.getElementType();
  return VectorType::get(vectorType.getShape().drop_front(position.size()),
                         vectorType.getElementType());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  result.addOperands(source);
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
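// Each value in `position` is expected to come from an arith.constant of
// index type (e.g. `%c0 = arith.constant 0 : index`); non-constant positions
// are not handled by this builder.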
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  build(builder, result, source, positionConstants);
}

static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << " " << op.vector() << op.position();
  p.printOptionalAttrDict(op->getAttrs(), {"position"});
  p << " : " << op.vector().getType();
}

static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
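/// Illustrative example (positions chosen arbitrarily):
///
///   %1 = vector.extract %0[0] : vector<2x3x4xf32>
///   %2 = vector.extract %1[1] : vector<3x4xf32>
///
/// is folded in place to the equivalent of
///
///   %2 = vector.extract %0[0, 1] : vector<2x3x4xf32>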
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extrPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extrPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result vector.
/// As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as equality
  /// of all entries of `a` that are >=0 with the corresponding entries of b.
  /// Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto it : llvm::zip(a, b)) {
      if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
        continue;
      if (std::get<0>(it) != std::get<1>(it))
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
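  // e.g. (illustrative) extracting at the exact position that was inserted:
  //   %ins = vector.insert %source, %dest[3, 2]
  //            : vector<4xf32> into vector<8x4x4xf32>
  //   %ext = vector.extract %ins[3, 2] : vector<8x4x4xf32>
  // folds directly to %source.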
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the result.
  /// They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
      extractedRank(extractOp.position().size()) {
  assert(vectorRank >= extractedRank && "extracted pos overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition = extractVector<int64_t>(extractOp.position());
  llvm::append_range(extractPosition, sentinels);
}

// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (!nextTransposeOp)
    return failure();
  auto permutation = extractVector<unsigned>(nextTransposeOp.transp());
  AffineMap m = inversePermutation(
      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
  extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
  return success();
}

// Case 2: the insert position matches extractPosition exactly, early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (makeArrayRef(insertedPos) !=
      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.source();
  // Case 2.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Case 3: if inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
  // Drop extra leading dims.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
  extractedRank = extractPosition.size() - sentinels.size();
  // Case 3.a. early-exit fold (break and delegate to post-while path).
  res = nextInsertOp.source();
  // Case 3.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // If we can't fold (either internal transposition, or nothing to fold), bail.
  bool nothingToFold = (source == extractOp.vector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold by updating the op inplace and return its result.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(
      extractOp.positionAttrName(),
      b.getI64ArrayAttr(
          makeArrayRef(extractPosition).take_front(extractedRank)));
  extractOp.vectorMutable().assign(source);
  return extractOp.getResult();
}

/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {
  Value valueToExtractFrom = extractOp.vector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.vector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.dest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.vector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();
  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(source.getType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(source);
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp->setAttr(ExtractOp::getPositionAttrName(),
                       b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  return Value();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
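  // Illustrative example:
  //   %sc = vector.shape_cast %src : vector<12x5xf32> to vector<3x4x5xf32>
  //   %e  = vector.extract %sc[1, 2] : vector<3x4x5xf32>
  // linearizes position [1, 2] to 1 * 4 + 2 = 6 and, after delinearizing
  // against the source shape below, folds to
  //   %e = vector.extract %src[6] : vector<12x5xf32>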
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}

/// Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  auto extractStridedSliceOp =
      extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();
  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank >
      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
    return Value();
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(extractedPos));
  return extractOp.getResult();
}

/// Fold extract_op fed from a chain of insertStridedSlice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
  int64_t destinationRank = op.getType().isa<VectorType>()
                                ? op.getType().cast<VectorType>().getRank()
                                : 0;
  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
    auto extractOffsets = extractVector<int64_t>(op.position());

    if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted element chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      op.vectorMutable().assign(insertOp.source());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractOp::getPositionAttrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return op.getResult();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
  if (position().empty())
    return vector();
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto res = foldExtractFromBroadcast(*this))
    return res;
  if (auto res = foldExtractFromShapeCast(*this))
    return res;
  if (auto val = foldExtractFromExtractStrided(*this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(*this))
    return val;
  return OpFoldResult();
}

namespace {

// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
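// Illustrative example:
//   %b = vector.broadcast %s : vector<16xf32> to vector<4x8x16xf32>
//   %e = vector.extract %b[2] : vector<4x8x16xf32>
// rewrites to
//   %e = vector.broadcast %s : vector<16xf32> to vector<8x16xf32>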
1481 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { 1482 public: 1483 using OpRewritePattern<ExtractOp>::OpRewritePattern; 1484 1485 LogicalResult matchAndRewrite(ExtractOp extractOp, 1486 PatternRewriter &rewriter) const override { 1487 Operation *defOp = extractOp.vector().getDefiningOp(); 1488 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) 1489 return failure(); 1490 Value source = defOp->getOperand(0); 1491 if (extractOp.getType() == source.getType()) 1492 return failure(); 1493 auto getRank = [](Type type) { 1494 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0; 1495 }; 1496 unsigned broadcastSrcRank = getRank(source.getType()); 1497 unsigned extractResultRank = getRank(extractOp.getType()); 1498 // We only consider the case where the rank of the source is smaller than 1499 // the rank of the extract dst. The other cases are handled in the folding 1500 // patterns. 1501 if (extractResultRank <= broadcastSrcRank) 1502 return failure(); 1503 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1504 extractOp, extractOp.getType(), source); 1505 return success(); 1506 } 1507 }; 1508 1509 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp. 1510 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> { 1511 public: 1512 using OpRewritePattern<ExtractOp>::OpRewritePattern; 1513 1514 LogicalResult matchAndRewrite(ExtractOp extractOp, 1515 PatternRewriter &rewriter) const override { 1516 // Return if 'extractStridedSliceOp' operand is not defined by a 1517 // ConstantOp. 1518 auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>(); 1519 if (!constantOp) 1520 return failure(); 1521 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 1522 if (!dense) 1523 return failure(); 1524 Attribute newAttr = dense.getSplatValue<Attribute>(); 1525 if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>()) 1526 newAttr = DenseElementsAttr::get(vecDstType, newAttr); 1527 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr); 1528 return success(); 1529 } 1530 }; 1531 1532 } // namespace 1533 1534 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, 1535 MLIRContext *context) { 1536 results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context); 1537 } 1538 1539 static void populateFromInt64AttrArray(ArrayAttr arrayAttr, 1540 SmallVectorImpl<int64_t> &results) { 1541 for (auto attr : arrayAttr) 1542 results.push_back(attr.cast<IntegerAttr>().getInt()); 1543 } 1544 1545 //===----------------------------------------------------------------------===// 1546 // ExtractMapOp 1547 //===----------------------------------------------------------------------===// 1548 1549 void ExtractMapOp::build(OpBuilder &builder, OperationState &result, 1550 Value vector, ValueRange ids, 1551 ArrayRef<int64_t> multiplicity, 1552 AffineMap permutationMap) { 1553 assert(ids.size() == multiplicity.size() && 1554 ids.size() == permutationMap.getNumResults()); 1555 assert(permutationMap.isProjectedPermutation()); 1556 VectorType type = vector.getType().cast<VectorType>(); 1557 SmallVector<int64_t, 4> newShape(type.getShape().begin(), 1558 type.getShape().end()); 1559 for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) { 1560 AffineExpr expr = permutationMap.getResult(i); 1561 auto dim = expr.cast<AffineDimExpr>(); 1562 newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i]; 1563 } 1564 VectorType resultType = VectorType::get(newShape, 
type.getElementType()); 1565 ExtractMapOp::build(builder, result, resultType, vector, ids); 1566 } 1567 1568 static LogicalResult verify(ExtractMapOp op) { 1569 if (op.getSourceVectorType().getRank() != op.getResultType().getRank()) 1570 return op.emitOpError( 1571 "expected source and destination vectors of same rank"); 1572 unsigned numId = 0; 1573 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) { 1574 if (op.getSourceVectorType().getDimSize(i) % 1575 op.getResultType().getDimSize(i) != 1576 0) 1577 return op.emitOpError("source vector dimensions must be a multiple of " 1578 "destination vector dimensions"); 1579 if (op.getSourceVectorType().getDimSize(i) != 1580 op.getResultType().getDimSize(i)) 1581 numId++; 1582 } 1583 if (numId != op.ids().size()) 1584 return op.emitOpError("expected number of ids must match the number of " 1585 "dimensions distributed"); 1586 return success(); 1587 } 1588 1589 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) { 1590 auto insert = vector().getDefiningOp<vector::InsertMapOp>(); 1591 if (insert == nullptr || getType() != insert.vector().getType() || 1592 ids() != insert.ids()) 1593 return {}; 1594 return insert.vector(); 1595 } 1596 1597 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) { 1598 assert(multiplicity.empty()); 1599 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) { 1600 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) 1601 multiplicity.push_back(getSourceVectorType().getDimSize(i) / 1602 getResultType().getDimSize(i)); 1603 } 1604 } 1605 1606 template <typename MapOp> 1607 AffineMap calculateImplicitMap(MapOp op) { 1608 SmallVector<AffineExpr, 4> perm; 1609 // Check which dimension have a multiplicity greater than 1 and associated 1610 // them to the IDs in order. 1611 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) { 1612 if (op.getSourceVectorType().getDimSize(i) != 1613 op.getResultType().getDimSize(i)) 1614 perm.push_back(getAffineDimExpr(i, op.getContext())); 1615 } 1616 auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm, 1617 op.getContext()); 1618 return map; 1619 } 1620 1621 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); } 1622 1623 //===----------------------------------------------------------------------===// 1624 // FmaOp 1625 //===----------------------------------------------------------------------===// 1626 1627 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() { 1628 return llvm::to_vector<4>(getVectorType().getShape()); 1629 } 1630 1631 //===----------------------------------------------------------------------===// 1632 // BroadcastOp 1633 //===----------------------------------------------------------------------===// 1634 1635 BroadcastableToResult 1636 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, 1637 std::pair<int, int> *mismatchingDims) { 1638 // Broadcast scalar to vector of the same element type. 1639 if (srcType.isIntOrIndexOrFloat() && dstVectorType && 1640 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) 1641 return BroadcastableToResult::Success; 1642 // From now on, only vectors broadcast. 
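  // A vector source must have rank no greater than the destination rank, and
  // each source dimension must either be 1 or equal the corresponding
  // trailing destination dimension. For example (illustrative), vector<1x8xf32>
  // is broadcastable to vector<4x8xf32>, while vector<3x8xf32> is not.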
1643 VectorType srcVectorType = srcType.dyn_cast<VectorType>(); 1644 if (!srcVectorType) 1645 return BroadcastableToResult::SourceTypeNotAVector; 1646 1647 int64_t srcRank = srcVectorType.getRank(); 1648 int64_t dstRank = dstVectorType.getRank(); 1649 if (srcRank > dstRank) 1650 return BroadcastableToResult::SourceRankHigher; 1651 // Source has an exact match or singleton value for all trailing dimensions 1652 // (all leading dimensions are simply duplicated). 1653 int64_t lead = dstRank - srcRank; 1654 for (int64_t r = 0; r < srcRank; ++r) { 1655 int64_t srcDim = srcVectorType.getDimSize(r); 1656 int64_t dstDim = dstVectorType.getDimSize(lead + r); 1657 if (srcDim != 1 && srcDim != dstDim) { 1658 if (mismatchingDims) { 1659 mismatchingDims->first = srcDim; 1660 mismatchingDims->second = dstDim; 1661 } 1662 return BroadcastableToResult::DimensionMismatch; 1663 } 1664 } 1665 1666 return BroadcastableToResult::Success; 1667 } 1668 1669 static LogicalResult verify(BroadcastOp op) { 1670 std::pair<int, int> mismatchingDims; 1671 BroadcastableToResult res = isBroadcastableTo( 1672 op.getSourceType(), op.getVectorType(), &mismatchingDims); 1673 if (res == BroadcastableToResult::Success) 1674 return success(); 1675 if (res == BroadcastableToResult::SourceRankHigher) 1676 return op.emitOpError("source rank higher than destination rank"); 1677 if (res == BroadcastableToResult::DimensionMismatch) 1678 return op.emitOpError("dimension mismatch (") 1679 << mismatchingDims.first << " vs. " << mismatchingDims.second << ")"; 1680 if (res == BroadcastableToResult::SourceTypeNotAVector) 1681 return op.emitOpError("source type is not a vector"); 1682 llvm_unreachable("unexpected vector.broadcast op error"); 1683 } 1684 1685 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 1686 if (getSourceType() == getVectorType()) 1687 return source(); 1688 if (!operands[0]) 1689 return {}; 1690 auto vectorType = getVectorType(); 1691 if (operands[0].getType().isIntOrIndexOrFloat()) 1692 return DenseElementsAttr::get(vectorType, operands[0]); 1693 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>()) 1694 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()); 1695 return {}; 1696 } 1697 1698 namespace { 1699 1700 // Fold broadcast1(broadcast2(x)) into broadcast1(x). 
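// Example (illustrative):
//   %0 = vector.broadcast %x : f32 to vector<8xf32>
//   %1 = vector.broadcast %0 : vector<8xf32> to vector<4x8xf32>
// rewrites to:
//   %1 = vector.broadcast %x : f32 to vector<4x8xf32>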
1701 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { 1702 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 1703 1704 LogicalResult matchAndRewrite(BroadcastOp broadcastOp, 1705 PatternRewriter &rewriter) const override { 1706 auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>(); 1707 if (!srcBroadcast) 1708 return failure(); 1709 rewriter.replaceOpWithNewOp<BroadcastOp>( 1710 broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source()); 1711 return success(); 1712 } 1713 }; 1714 } // namespace 1715 1716 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, 1717 MLIRContext *context) { 1718 // BroadcastToShapeCast is not a default canonicalization, it is opt-in by 1719 // calling `populateCastAwayVectorLeadingOneDimPatterns` 1720 results.add<BroadcastFolder>(context); 1721 } 1722 1723 //===----------------------------------------------------------------------===// 1724 // ShuffleOp 1725 //===----------------------------------------------------------------------===// 1726 1727 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, 1728 Value v2, ArrayRef<int64_t> mask) { 1729 result.addOperands({v1, v2}); 1730 auto maskAttr = getVectorSubscriptAttr(builder, mask); 1731 auto v1Type = v1.getType().cast<VectorType>(); 1732 auto shape = llvm::to_vector<4>(v1Type.getShape()); 1733 shape[0] = mask.size(); 1734 result.addTypes(VectorType::get(shape, v1Type.getElementType())); 1735 result.addAttribute(getMaskAttrName(), maskAttr); 1736 } 1737 1738 static void print(OpAsmPrinter &p, ShuffleOp op) { 1739 p << " " << op.v1() << ", " << op.v2() << " " << op.mask(); 1740 p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()}); 1741 p << " : " << op.v1().getType() << ", " << op.v2().getType(); 1742 } 1743 1744 static LogicalResult verify(ShuffleOp op) { 1745 VectorType resultType = op.getVectorType(); 1746 VectorType v1Type = op.getV1VectorType(); 1747 VectorType v2Type = op.getV2VectorType(); 1748 // Verify ranks. 1749 int64_t resRank = resultType.getRank(); 1750 int64_t v1Rank = v1Type.getRank(); 1751 int64_t v2Rank = v2Type.getRank(); 1752 if (resRank != v1Rank || v1Rank != v2Rank) 1753 return op.emitOpError("rank mismatch"); 1754 // Verify all but leading dimension sizes. 1755 for (int64_t r = 1; r < v1Rank; ++r) { 1756 int64_t resDim = resultType.getDimSize(r); 1757 int64_t v1Dim = v1Type.getDimSize(r); 1758 int64_t v2Dim = v2Type.getDimSize(r); 1759 if (resDim != v1Dim || v1Dim != v2Dim) 1760 return op.emitOpError("dimension mismatch"); 1761 } 1762 // Verify mask length. 1763 auto maskAttr = op.mask().getValue(); 1764 int64_t maskLength = maskAttr.size(); 1765 if (maskLength != resultType.getDimSize(0)) 1766 return op.emitOpError("mask length mismatch"); 1767 // Verify all indices. 
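  // Each mask element selects a lane from the concatenation of the two
  // operands, so it must lie in [0, dim(v1, 0) + dim(v2, 0)). For example
  // (illustrative), shuffling two vector<2xf32> operands admits mask values
  // 0 through 3.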
1768 int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); 1769 for (const auto &en : llvm::enumerate(maskAttr)) { 1770 auto attr = en.value().dyn_cast<IntegerAttr>(); 1771 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) 1772 return op.emitOpError("mask index #") 1773 << (en.index() + 1) << " out of range"; 1774 } 1775 return success(); 1776 } 1777 1778 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) { 1779 OpAsmParser::OperandType v1, v2; 1780 Attribute attr; 1781 VectorType v1Type, v2Type; 1782 if (parser.parseOperand(v1) || parser.parseComma() || 1783 parser.parseOperand(v2) || 1784 parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(), 1785 result.attributes) || 1786 parser.parseOptionalAttrDict(result.attributes) || 1787 parser.parseColonType(v1Type) || parser.parseComma() || 1788 parser.parseType(v2Type) || 1789 parser.resolveOperand(v1, v1Type, result.operands) || 1790 parser.resolveOperand(v2, v2Type, result.operands)) 1791 return failure(); 1792 // Construct resulting type: leading dimension matches mask length, 1793 // all trailing dimensions match the operands. 1794 auto maskAttr = attr.dyn_cast<ArrayAttr>(); 1795 if (!maskAttr) 1796 return parser.emitError(parser.getNameLoc(), "missing mask attribute"); 1797 int64_t maskLength = maskAttr.size(); 1798 if (maskLength <= 0) 1799 return parser.emitError(parser.getNameLoc(), "invalid mask length"); 1800 int64_t v1Rank = v1Type.getRank(); 1801 SmallVector<int64_t, 4> shape; 1802 shape.reserve(v1Rank); 1803 shape.push_back(maskLength); 1804 for (int64_t r = 1; r < v1Rank; ++r) 1805 shape.push_back(v1Type.getDimSize(r)); 1806 VectorType resType = VectorType::get(shape, v1Type.getElementType()); 1807 parser.addTypeToList(resType, result.types); 1808 return success(); 1809 } 1810 1811 //===----------------------------------------------------------------------===// 1812 // InsertElementOp 1813 //===----------------------------------------------------------------------===// 1814 1815 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 1816 Value source, Value dest) { 1817 result.addOperands({source, dest}); 1818 result.addTypes(dest.getType()); 1819 } 1820 1821 void InsertElementOp::build(OpBuilder &builder, OperationState &result, 1822 Value source, Value dest, Value position) { 1823 result.addOperands({source, dest, position}); 1824 result.addTypes(dest.getType()); 1825 } 1826 1827 static LogicalResult verify(InsertElementOp op) { 1828 auto dstVectorType = op.getDestVectorType(); 1829 if (dstVectorType.getRank() == 0) { 1830 if (op.position()) 1831 return op.emitOpError("expected position to be empty with 0-D vector"); 1832 return success(); 1833 } 1834 if (dstVectorType.getRank() != 1) 1835 return op.emitOpError("unexpected >1 vector rank"); 1836 if (!op.position()) 1837 return op.emitOpError("expected position for 1-D vector"); 1838 return success(); 1839 } 1840 1841 //===----------------------------------------------------------------------===// 1842 // InsertOp 1843 //===----------------------------------------------------------------------===// 1844 1845 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1846 Value dest, ArrayRef<int64_t> position) { 1847 result.addOperands({source, dest}); 1848 auto positionAttr = getVectorSubscriptAttr(builder, position); 1849 result.addTypes(dest.getType()); 1850 result.addAttribute(getPositionAttrName(), positionAttr); 1851 } 1852 1853 // Convenience builder which assumes the values are 
constant indices. 1854 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, 1855 Value dest, ValueRange position) { 1856 SmallVector<int64_t, 4> positionConstants = 1857 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { 1858 return pos.getDefiningOp<arith::ConstantIndexOp>().value(); 1859 })); 1860 build(builder, result, source, dest, positionConstants); 1861 } 1862 1863 static LogicalResult verify(InsertOp op) { 1864 auto positionAttr = op.position().getValue(); 1865 auto destVectorType = op.getDestVectorType(); 1866 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank())) 1867 return op.emitOpError( 1868 "expected position attribute of rank smaller than dest vector rank"); 1869 auto srcVectorType = op.getSourceType().dyn_cast<VectorType>(); 1870 if (srcVectorType && 1871 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() != 1872 static_cast<unsigned>(destVectorType.getRank()))) 1873 return op.emitOpError("expected position attribute rank + source rank to " 1874 "match dest vector rank"); 1875 if (!srcVectorType && 1876 (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank()))) 1877 return op.emitOpError( 1878 "expected position attribute rank to match the dest vector rank"); 1879 for (const auto &en : llvm::enumerate(positionAttr)) { 1880 auto attr = en.value().dyn_cast<IntegerAttr>(); 1881 if (!attr || attr.getInt() < 0 || 1882 attr.getInt() >= destVectorType.getDimSize(en.index())) 1883 return op.emitOpError("expected position attribute #") 1884 << (en.index() + 1) 1885 << " to be a non-negative integer smaller than the corresponding " 1886 "dest vector dimension"; 1887 } 1888 return success(); 1889 } 1890 1891 namespace { 1892 1893 // If insertOp is only inserting unit dimensions it can be transformed to a 1894 // broadcast. 1895 class InsertToBroadcast final : public OpRewritePattern<InsertOp> { 1896 public: 1897 using OpRewritePattern<InsertOp>::OpRewritePattern; 1898 1899 LogicalResult matchAndRewrite(InsertOp insertOp, 1900 PatternRewriter &rewriter) const override { 1901 auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>(); 1902 if (!srcVecType || insertOp.getDestVectorType().getNumElements() != 1903 srcVecType.getNumElements()) 1904 return failure(); 1905 rewriter.replaceOpWithNewOp<BroadcastOp>( 1906 insertOp, insertOp.getDestVectorType(), insertOp.source()); 1907 return success(); 1908 } 1909 }; 1910 1911 } // namespace 1912 1913 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, 1914 MLIRContext *context) { 1915 results.add<InsertToBroadcast, BroadcastFolder>(context); 1916 } 1917 1918 // Eliminates insert operations that produce values identical to their source 1919 // value. This happens when the source and destination vectors have identical 1920 // sizes. 
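// For example (illustrative), inserting a vector<4x8xf32> source into a
// vector<4x8xf32> destination at an empty position folds to the source value.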
1921 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) { 1922 if (position().empty()) 1923 return source(); 1924 return {}; 1925 } 1926 1927 //===----------------------------------------------------------------------===// 1928 // InsertMapOp 1929 //===----------------------------------------------------------------------===// 1930 1931 void InsertMapOp::build(OpBuilder &builder, OperationState &result, 1932 Value vector, Value dest, ValueRange ids) { 1933 InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids); 1934 } 1935 1936 static LogicalResult verify(InsertMapOp op) { 1937 if (op.getSourceVectorType().getRank() != op.getResultType().getRank()) 1938 return op.emitOpError( 1939 "expected source and destination vectors of same rank"); 1940 unsigned numId = 0; 1941 for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) { 1942 if (op.getResultType().getDimSize(i) % 1943 op.getSourceVectorType().getDimSize(i) != 1944 0) 1945 return op.emitOpError( 1946 "destination vector size must be a multiple of source vector size"); 1947 if (op.getResultType().getDimSize(i) != 1948 op.getSourceVectorType().getDimSize(i)) 1949 numId++; 1950 } 1951 if (numId != op.ids().size()) 1952 return op.emitOpError("expected number of ids must match the number of " 1953 "dimensions distributed"); 1954 return success(); 1955 } 1956 1957 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); } 1958 1959 //===----------------------------------------------------------------------===// 1960 // InsertStridedSliceOp 1961 //===----------------------------------------------------------------------===// 1962 1963 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result, 1964 Value source, Value dest, 1965 ArrayRef<int64_t> offsets, 1966 ArrayRef<int64_t> strides) { 1967 result.addOperands({source, dest}); 1968 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 1969 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 1970 result.addTypes(dest.getType()); 1971 result.addAttribute(getOffsetsAttrName(), offsetsAttr); 1972 result.addAttribute(getStridesAttrName(), stridesAttr); 1973 } 1974 1975 // TODO: Should be moved to Tablegen Confined attributes. 1976 template <typename OpType> 1977 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, 1978 ArrayAttr arrayAttr, 1979 ArrayRef<int64_t> shape, 1980 StringRef attrName) { 1981 if (arrayAttr.size() > shape.size()) 1982 return op.emitOpError("expected ") 1983 << attrName << " attribute of rank smaller than vector rank"; 1984 return success(); 1985 } 1986 1987 // Returns true if all integers in `arrayAttr` are in the half-open [min, max} 1988 // interval. If `halfOpen` is true then the admissible interval is [min, max). 1989 // Otherwise, the admissible interval is [min, max]. 1990 template <typename OpType> 1991 static LogicalResult 1992 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, 1993 int64_t max, StringRef attrName, 1994 bool halfOpen = true) { 1995 for (auto attr : arrayAttr) { 1996 auto val = attr.cast<IntegerAttr>().getInt(); 1997 auto upper = max; 1998 if (!halfOpen) 1999 upper += 1; 2000 if (val < min || val >= upper) 2001 return op.emitOpError("expected ") << attrName << " to be confined to [" 2002 << min << ", " << upper << ")"; 2003 } 2004 return success(); 2005 } 2006 2007 // Returns true if all integers in `arrayAttr` are in the half-open [min, max} 2008 // interval. If `halfOpen` is true then the admissible interval is [min, max). 
2009 // Otherwise, the admissible interval is [min, max]. 2010 template <typename OpType> 2011 static LogicalResult 2012 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, 2013 ArrayRef<int64_t> shape, StringRef attrName, 2014 bool halfOpen = true, int64_t min = 0) { 2015 assert(arrayAttr.size() <= shape.size()); 2016 unsigned index = 0; 2017 for (auto it : llvm::zip(arrayAttr, shape)) { 2018 auto val = std::get<0>(it).cast<IntegerAttr>().getInt(); 2019 auto max = std::get<1>(it); 2020 if (!halfOpen) 2021 max += 1; 2022 if (val < min || val >= max) 2023 return op.emitOpError("expected ") 2024 << attrName << " dimension " << index << " to be confined to [" 2025 << min << ", " << max << ")"; 2026 ++index; 2027 } 2028 return success(); 2029 } 2030 2031 // Returns true if all integers in `arrayAttr` are in the interval [min, max}. 2032 // interval. If `halfOpen` is true then the admissible interval is [min, max). 2033 // Otherwise, the admissible interval is [min, max]. 2034 template <typename OpType> 2035 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape( 2036 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, 2037 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2, 2038 bool halfOpen = true, int64_t min = 1) { 2039 assert(arrayAttr1.size() <= shape.size()); 2040 assert(arrayAttr2.size() <= shape.size()); 2041 unsigned index = 0; 2042 for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) { 2043 auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt(); 2044 auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt(); 2045 auto max = std::get<2>(it); 2046 if (!halfOpen) 2047 max += 1; 2048 if (val1 + val2 < 0 || val1 + val2 >= max) 2049 return op.emitOpError("expected sum(") 2050 << attrName1 << ", " << attrName2 << ") dimension " << index 2051 << " to be confined to [" << min << ", " << max << ")"; 2052 ++index; 2053 } 2054 return success(); 2055 } 2056 2057 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values, 2058 MLIRContext *context) { 2059 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { 2060 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); 2061 }); 2062 return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); 2063 } 2064 2065 static LogicalResult verify(InsertStridedSliceOp op) { 2066 auto sourceVectorType = op.getSourceVectorType(); 2067 auto destVectorType = op.getDestVectorType(); 2068 auto offsets = op.offsets(); 2069 auto strides = op.strides(); 2070 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank())) 2071 return op.emitOpError( 2072 "expected offsets of same size as destination vector rank"); 2073 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank())) 2074 return op.emitOpError( 2075 "expected strides of same size as source vector rank"); 2076 if (sourceVectorType.getRank() > destVectorType.getRank()) 2077 return op.emitOpError( 2078 "expected source rank to be smaller than destination rank"); 2079 2080 auto sourceShape = sourceVectorType.getShape(); 2081 auto destShape = destVectorType.getShape(); 2082 SmallVector<int64_t, 4> sourceShapeAsDestShape( 2083 destShape.size() - sourceShape.size(), 0); 2084 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end()); 2085 auto offName = InsertStridedSliceOp::getOffsetsAttrName(); 2086 auto stridesName = InsertStridedSliceOp::getStridesAttrName(); 2087 if (failed( 2088 isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) || 2089 failed(isIntegerArrayAttrConfinedToRange(op, strides, 
1, 1, stridesName, 2090 /*halfOpen=*/false)) || 2091 failed(isSumOfIntegerArrayAttrConfinedToShape( 2092 op, offsets, 2093 makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape, 2094 offName, "source vector shape", 2095 /*halfOpen=*/false, /*min=*/1))) 2096 return failure(); 2097 2098 return success(); 2099 } 2100 2101 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2102 if (getSourceVectorType() == getDestVectorType()) 2103 return source(); 2104 return {}; 2105 } 2106 2107 //===----------------------------------------------------------------------===// 2108 // OuterProductOp 2109 //===----------------------------------------------------------------------===// 2110 2111 /// Build an op without mask, use the type of `acc` as the return type. 2112 void OuterProductOp::build(OpBuilder &builder, OperationState &result, 2113 Value lhs, Value rhs, Value acc) { 2114 result.addOperands({lhs, rhs, acc}); 2115 result.addTypes(acc.getType()); 2116 } 2117 2118 static void print(OpAsmPrinter &p, OuterProductOp op) { 2119 p << " " << op.lhs() << ", " << op.rhs(); 2120 if (!op.acc().empty()) { 2121 p << ", " << op.acc(); 2122 p.printOptionalAttrDict(op->getAttrs()); 2123 } 2124 p << " : " << op.lhs().getType() << ", " << op.rhs().getType(); 2125 } 2126 2127 static ParseResult parseOuterProductOp(OpAsmParser &parser, 2128 OperationState &result) { 2129 SmallVector<OpAsmParser::OperandType, 3> operandsInfo; 2130 Type tLHS, tRHS; 2131 if (parser.parseOperandList(operandsInfo) || 2132 parser.parseOptionalAttrDict(result.attributes) || 2133 parser.parseColonType(tLHS) || parser.parseComma() || 2134 parser.parseType(tRHS)) 2135 return failure(); 2136 if (operandsInfo.size() < 2) 2137 return parser.emitError(parser.getNameLoc(), 2138 "expected at least 2 operands"); 2139 VectorType vLHS = tLHS.dyn_cast<VectorType>(); 2140 VectorType vRHS = tRHS.dyn_cast<VectorType>(); 2141 if (!vLHS) 2142 return parser.emitError(parser.getNameLoc(), 2143 "expected vector type for operand #1"); 2144 VectorType resType = 2145 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, 2146 vLHS.getElementType()) 2147 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); 2148 2149 if (!result.attributes.get(OuterProductOp::getKindAttrName())) { 2150 result.attributes.append( 2151 OuterProductOp::getKindAttrName(), 2152 CombiningKindAttr::get(OuterProductOp::getDefaultKind(), 2153 result.getContext())); 2154 } 2155 2156 return failure( 2157 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || 2158 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || 2159 (operandsInfo.size() > 2 && 2160 parser.resolveOperand(operandsInfo[2], resType, result.operands)) || 2161 parser.addTypeToList(resType, result.types)); 2162 } 2163 2164 static LogicalResult verify(OuterProductOp op) { 2165 Type tRHS = op.getOperandTypeRHS(); 2166 VectorType vLHS = op.getOperandVectorTypeLHS(), 2167 vRHS = tRHS.dyn_cast<VectorType>(), 2168 vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType(); 2169 2170 if (vLHS.getRank() != 1) 2171 return op.emitOpError("expected 1-d vector for operand #1"); 2172 2173 if (vRHS) { 2174 // Proper OUTER operation. 
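    // For example (illustrative):
    //   vector.outerproduct %a, %b, %acc : vector<4xf32>, vector<8xf32>
    // must produce a vector<4x8xf32> result.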
2175 if (vRHS.getRank() != 1) 2176 return op.emitOpError("expected 1-d vector for operand #2"); 2177 if (vRES.getRank() != 2) 2178 return op.emitOpError("expected 2-d vector result"); 2179 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2180 return op.emitOpError("expected #1 operand dim to match result dim #1"); 2181 if (vRHS.getDimSize(0) != vRES.getDimSize(1)) 2182 return op.emitOpError("expected #2 operand dim to match result dim #2"); 2183 } else { 2184 // An AXPY operation. 2185 if (vRES.getRank() != 1) 2186 return op.emitOpError("expected 1-d vector result"); 2187 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 2188 return op.emitOpError("expected #1 operand dim to match result dim #1"); 2189 } 2190 2191 if (vACC && vACC != vRES) 2192 return op.emitOpError("expected operand #3 of same type as result type"); 2193 2194 // Verify supported combining kind. 2195 if (!isSupportedCombiningKind(op.kind(), vRES.getElementType())) 2196 return op.emitOpError("unsupported outerproduct type"); 2197 2198 return success(); 2199 } 2200 2201 //===----------------------------------------------------------------------===// 2202 // ReshapeOp 2203 //===----------------------------------------------------------------------===// 2204 2205 static LogicalResult verify(ReshapeOp op) { 2206 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. 2207 auto inputVectorType = op.getInputVectorType(); 2208 auto outputVectorType = op.getOutputVectorType(); 2209 int64_t inputShapeRank = op.getNumInputShapeSizes(); 2210 int64_t outputShapeRank = op.getNumOutputShapeSizes(); 2211 SmallVector<int64_t, 4> fixedVectorSizes; 2212 op.getFixedVectorSizes(fixedVectorSizes); 2213 int64_t numFixedVectorSizes = fixedVectorSizes.size(); 2214 2215 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) 2216 return op.emitError("invalid input shape for vector type ") 2217 << inputVectorType; 2218 2219 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) 2220 return op.emitError("invalid output shape for vector type ") 2221 << outputVectorType; 2222 2223 // Verify that the 'fixedVectorSizes' match an input/output vector shape 2224 // suffix. 2225 unsigned inputVectorRank = inputVectorType.getRank(); 2226 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2227 unsigned index = inputVectorRank - numFixedVectorSizes - i; 2228 if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) 2229 return op.emitError("fixed vector size must match input vector for dim ") 2230 << i; 2231 } 2232 2233 unsigned outputVectorRank = outputVectorType.getRank(); 2234 for (unsigned i = 0; i < numFixedVectorSizes; ++i) { 2235 unsigned index = outputVectorRank - numFixedVectorSizes - i; 2236 if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) 2237 return op.emitError("fixed vector size must match output vector for dim ") 2238 << i; 2239 } 2240 2241 // If all shape operands are produced by constant ops, verify that product 2242 // of dimensions for input/output shape match. 
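  // For example (illustrative), constant input shape sizes (2, 6) with output
  // shape sizes (3, 4) are accepted (both products are 12), whereas output
  // shape sizes (3, 5) would be rejected.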
2243 auto isDefByConstant = [](Value operand) { 2244 return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp()); 2245 }; 2246 if (llvm::all_of(op.input_shape(), isDefByConstant) && 2247 llvm::all_of(op.output_shape(), isDefByConstant)) { 2248 int64_t numInputElements = 1; 2249 for (auto operand : op.input_shape()) 2250 numInputElements *= 2251 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2252 int64_t numOutputElements = 1; 2253 for (auto operand : op.output_shape()) 2254 numOutputElements *= 2255 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); 2256 if (numInputElements != numOutputElements) 2257 return op.emitError("product of input and output shape sizes must match"); 2258 } 2259 return success(); 2260 } 2261 2262 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) { 2263 populateFromInt64AttrArray(fixed_vector_sizes(), results); 2264 } 2265 2266 //===----------------------------------------------------------------------===// 2267 // ExtractStridedSliceOp 2268 //===----------------------------------------------------------------------===// 2269 2270 // Inference works as follows: 2271 // 1. Add 'sizes' from prefix of dims in 'offsets'. 2272 // 2. Add sizes from 'vectorType' for remaining dims. 2273 static Type inferStridedSliceOpResultType(VectorType vectorType, 2274 ArrayAttr offsets, ArrayAttr sizes, 2275 ArrayAttr strides) { 2276 assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); 2277 SmallVector<int64_t, 4> shape; 2278 shape.reserve(vectorType.getRank()); 2279 unsigned idx = 0; 2280 for (unsigned e = offsets.size(); idx < e; ++idx) 2281 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt()); 2282 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) 2283 shape.push_back(vectorType.getShape()[idx]); 2284 2285 return VectorType::get(shape, vectorType.getElementType()); 2286 } 2287 2288 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, 2289 Value source, ArrayRef<int64_t> offsets, 2290 ArrayRef<int64_t> sizes, 2291 ArrayRef<int64_t> strides) { 2292 result.addOperands(source); 2293 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); 2294 auto sizesAttr = getVectorSubscriptAttr(builder, sizes); 2295 auto stridesAttr = getVectorSubscriptAttr(builder, strides); 2296 result.addTypes( 2297 inferStridedSliceOpResultType(source.getType().cast<VectorType>(), 2298 offsetsAttr, sizesAttr, stridesAttr)); 2299 result.addAttribute(getOffsetsAttrName(), offsetsAttr); 2300 result.addAttribute(getSizesAttrName(), sizesAttr); 2301 result.addAttribute(getStridesAttrName(), stridesAttr); 2302 } 2303 2304 static LogicalResult verify(ExtractStridedSliceOp op) { 2305 auto type = op.getVectorType(); 2306 auto offsets = op.offsets(); 2307 auto sizes = op.sizes(); 2308 auto strides = op.strides(); 2309 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) { 2310 op.emitOpError( 2311 "expected offsets, sizes and strides attributes of same size"); 2312 return failure(); 2313 } 2314 2315 auto shape = type.getShape(); 2316 auto offName = ExtractStridedSliceOp::getOffsetsAttrName(); 2317 auto sizesName = ExtractStridedSliceOp::getSizesAttrName(); 2318 auto stridesName = ExtractStridedSliceOp::getStridesAttrName(); 2319 if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) || 2320 failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) || 2321 failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape, 2322 stridesName)) || 
2323 failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) || 2324 failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName, 2325 /*halfOpen=*/false, 2326 /*min=*/1)) || 2327 failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName, 2328 /*halfOpen=*/false)) || 2329 failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape, 2330 offName, sizesName, 2331 /*halfOpen=*/false))) 2332 return failure(); 2333 2334 auto resultType = inferStridedSliceOpResultType( 2335 op.getVectorType(), op.offsets(), op.sizes(), op.strides()); 2336 if (op.getResult().getType() != resultType) { 2337 op.emitOpError("expected result type to be ") << resultType; 2338 return failure(); 2339 } 2340 2341 return success(); 2342 } 2343 2344 // When the source of ExtractStrided comes from a chain of InsertStrided ops try 2345 // to use the source of the InsertStrided ops if we can detect that the 2346 // extracted vector is a subset of one of the vector inserted. 2347 static LogicalResult 2348 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { 2349 // Helper to extract integer out of ArrayAttr. 2350 auto getElement = [](ArrayAttr array, int idx) { 2351 return array[idx].cast<IntegerAttr>().getInt(); 2352 }; 2353 ArrayAttr extractOffsets = op.offsets(); 2354 ArrayAttr extractStrides = op.strides(); 2355 ArrayAttr extractSizes = op.sizes(); 2356 auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>(); 2357 while (insertOp) { 2358 if (op.getVectorType().getRank() != 2359 insertOp.getSourceVectorType().getRank()) 2360 return failure(); 2361 ArrayAttr insertOffsets = insertOp.offsets(); 2362 ArrayAttr insertStrides = insertOp.strides(); 2363 // If the rank of extract is greater than the rank of insert, we are likely 2364 // extracting a partial chunk of the vector inserted. 2365 if (extractOffsets.size() > insertOffsets.size()) 2366 return failure(); 2367 bool patialoverlap = false; 2368 bool disjoint = false; 2369 SmallVector<int64_t, 4> offsetDiffs; 2370 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) { 2371 if (getElement(extractStrides, dim) != getElement(insertStrides, dim)) 2372 return failure(); 2373 int64_t start = getElement(insertOffsets, dim); 2374 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim); 2375 int64_t offset = getElement(extractOffsets, dim); 2376 int64_t size = getElement(extractSizes, dim); 2377 // Check if the start of the extract offset is in the interval inserted. 2378 if (start <= offset && offset < end) { 2379 // If the extract interval overlaps but is not fully included we may 2380 // have a partial overlap that will prevent any folding. 2381 if (offset + size > end) 2382 patialoverlap = true; 2383 offsetDiffs.push_back(offset - start); 2384 continue; 2385 } 2386 disjoint = true; 2387 break; 2388 } 2389 // The extract element chunk is a subset of the insert element. 2390 if (!disjoint && !patialoverlap) { 2391 op.setOperand(insertOp.source()); 2392 // OpBuilder is only used as a helper to build an I64ArrayAttr. 2393 OpBuilder b(op.getContext()); 2394 op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(), 2395 b.getI64ArrayAttr(offsetDiffs)); 2396 return success(); 2397 } 2398 // If the chunk extracted is disjoint from the chunk inserted, keep looking 2399 // in the insert chain. 2400 if (disjoint) 2401 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>(); 2402 else { 2403 // The extracted vector partially overlap the inserted vector, we cannot 2404 // fold. 
2405 return failure(); 2406 } 2407 } 2408 return failure(); 2409 } 2410 2411 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) { 2412 if (getVectorType() == getResult().getType()) 2413 return vector(); 2414 if (succeeded(foldExtractStridedOpFromInsertChain(*this))) 2415 return getResult(); 2416 return {}; 2417 } 2418 2419 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) { 2420 populateFromInt64AttrArray(offsets(), results); 2421 } 2422 2423 namespace { 2424 2425 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to 2426 // ConstantMaskOp. 2427 class StridedSliceConstantMaskFolder final 2428 : public OpRewritePattern<ExtractStridedSliceOp> { 2429 public: 2430 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2431 2432 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2433 PatternRewriter &rewriter) const override { 2434 // Return if 'extractStridedSliceOp' operand is not defined by a 2435 // ConstantMaskOp. 2436 auto *defOp = extractStridedSliceOp.vector().getDefiningOp(); 2437 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp); 2438 if (!constantMaskOp) 2439 return failure(); 2440 // Return if 'extractStridedSliceOp' has non-unit strides. 2441 if (extractStridedSliceOp.hasNonUnitStrides()) 2442 return failure(); 2443 // Gather constant mask dimension sizes. 2444 SmallVector<int64_t, 4> maskDimSizes; 2445 populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); 2446 // Gather strided slice offsets and sizes. 2447 SmallVector<int64_t, 4> sliceOffsets; 2448 populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets); 2449 SmallVector<int64_t, 4> sliceSizes; 2450 populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes); 2451 2452 // Compute slice of vector mask region. 2453 SmallVector<int64_t, 4> sliceMaskDimSizes; 2454 assert(sliceOffsets.size() == maskDimSizes.size()); 2455 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { 2456 int64_t maskDimSize = std::get<0>(it); 2457 int64_t sliceOffset = std::get<1>(it); 2458 int64_t sliceSize = std::get<2>(it); 2459 int64_t sliceMaskDimSize = std::max( 2460 static_cast<int64_t>(0), 2461 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset); 2462 sliceMaskDimSizes.push_back(sliceMaskDimSize); 2463 } 2464 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked 2465 // region is a conjunction of mask dim intervals). 2466 if (llvm::is_contained(sliceMaskDimSizes, 0)) 2467 sliceMaskDimSizes.assign(maskDimSizes.size(), 0); 2468 2469 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask 2470 // region. 2471 rewriter.replaceOpWithNewOp<ConstantMaskOp>( 2472 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), 2473 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); 2474 return success(); 2475 } 2476 }; 2477 2478 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. 2479 class StridedSliceConstantFolder final 2480 : public OpRewritePattern<ExtractStridedSliceOp> { 2481 public: 2482 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 2483 2484 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, 2485 PatternRewriter &rewriter) const override { 2486 // Return if 'extractStridedSliceOp' operand is not defined by a 2487 // ConstantOp. 
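    // For example (illustrative), a strided slice of a splat constant
    // vector<8x8xf32> filled with 1.0 is rewritten to a splat constant of the
    // slice type.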
2488     auto constantOp =
2489         extractStridedSliceOp.vector().getDefiningOp<arith::ConstantOp>();
2490     if (!constantOp)
2491       return failure();
2492     auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
2493     if (!dense)
2494       return failure();
2495     auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
2496                                           dense.getSplatValue<Attribute>());
2497     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
2498                                                    newAttr);
2499     return success();
2500   }
2501 };
2502
2503 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
2504 // BroadcastOp(ExtractStridedSliceOp).
2505 class StridedSliceBroadcast final
2506     : public OpRewritePattern<ExtractStridedSliceOp> {
2507 public:
2508   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2509
2510   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2511                                 PatternRewriter &rewriter) const override {
2512     auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
2513     if (!broadcast)
2514       return failure();
2515     auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
2516     unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
2517     auto dstVecType = op.getType().cast<VectorType>();
2518     unsigned dstRank = dstVecType.getRank();
2519     unsigned rankDiff = dstRank - srcRank;
2520     // Check if the innermost dimensions of the broadcast source are the same
2521     // as the corresponding dimensions of the extract destination. If so, we
2522     // can just use a broadcast, as the original dimensions are untouched.
2523     bool lowerDimMatch = true;
2524     for (unsigned i = 0; i < srcRank; i++) {
2525       if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
2526         lowerDimMatch = false;
2527         break;
2528       }
2529     }
2530     Value source = broadcast.source();
2531     if (!lowerDimMatch) {
2532       // The inner dimensions don't match, so we need to extract from the
2533       // source of the original broadcast and then broadcast the extracted value.
2534       source = rewriter.create<ExtractStridedSliceOp>(
2535           op->getLoc(), source,
2536           getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
2537           getI64SubArray(op.sizes(), /*dropFront=*/rankDiff),
2538           getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
2539     }
2540     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
2541     return success();
2542   }
2543 };
2544
2545 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
2546 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
2547 public:
2548   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2549
2550   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2551                                 PatternRewriter &rewriter) const override {
2552     auto splat = op.vector().getDefiningOp<SplatOp>();
2553     if (!splat)
2554       return failure();
2555     rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
2556     return success();
2557   }
2558 };
2559
2560 } // namespace
2561
2562 void ExtractStridedSliceOp::getCanonicalizationPatterns(
2563     RewritePatternSet &results, MLIRContext *context) {
2564   // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
2565   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2566 results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder, 2567 StridedSliceBroadcast, StridedSliceSplat>(context); 2568 } 2569 2570 //===----------------------------------------------------------------------===// 2571 // TransferReadOp 2572 //===----------------------------------------------------------------------===// 2573 2574 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). 2575 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2576 VectorType vectorType, Value source, 2577 ValueRange indices, AffineMapAttr permutationMapAttr, 2578 /*optional*/ ArrayAttr inBoundsAttr) { 2579 Type elemType = source.getType().cast<ShapedType>().getElementType(); 2580 Value padding = builder.create<arith::ConstantOp>( 2581 result.location, elemType, builder.getZeroAttr(elemType)); 2582 build(builder, result, vectorType, source, indices, permutationMapAttr, 2583 padding, /*mask=*/Value(), inBoundsAttr); 2584 } 2585 2586 /// 2. Builder that sets padding to zero an empty mask (variant without attrs). 2587 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2588 VectorType vectorType, Value source, 2589 ValueRange indices, AffineMap permutationMap, 2590 Optional<ArrayRef<bool>> inBounds) { 2591 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 2592 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 2593 ? builder.getBoolArrayAttr(inBounds.getValue()) 2594 : ArrayAttr(); 2595 build(builder, result, vectorType, source, indices, permutationMapAttr, 2596 inBoundsAttr); 2597 } 2598 2599 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. 2600 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2601 VectorType vectorType, Value source, 2602 ValueRange indices, Value padding, 2603 Optional<ArrayRef<bool>> inBounds) { 2604 AffineMap permutationMap = getTransferMinorIdentityMap( 2605 source.getType().cast<ShapedType>(), vectorType); 2606 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 2607 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 2608 ? builder.getBoolArrayAttr(inBounds.getValue()) 2609 : ArrayAttr(); 2610 build(builder, result, vectorType, source, indices, permutationMapAttr, 2611 padding, 2612 /*mask=*/Value(), inBoundsAttr); 2613 } 2614 2615 /// 4. Builder that sets padding to zero and permutation map to 2616 /// 'getMinorIdentityMap'. 
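/// For a memref<?x?xf32> source and a vector<4x8xf32> result this builds IR
/// equivalent to (illustrative; exact constant printing may differ):
///   %pad = arith.constant 0.0 : f32
///   %v = vector.transfer_read %src[%i, %j], %pad
///       : memref<?x?xf32>, vector<4x8xf32>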
2617 void TransferReadOp::build(OpBuilder &builder, OperationState &result, 2618 VectorType vectorType, Value source, 2619 ValueRange indices, 2620 Optional<ArrayRef<bool>> inBounds) { 2621 Type elemType = source.getType().cast<ShapedType>().getElementType(); 2622 Value padding = builder.create<arith::ConstantOp>( 2623 result.location, elemType, builder.getZeroAttr(elemType)); 2624 build(builder, result, vectorType, source, indices, padding, inBounds); 2625 } 2626 2627 template <typename EmitFun> 2628 static LogicalResult verifyPermutationMap(AffineMap permutationMap, 2629 EmitFun emitOpError) { 2630 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); 2631 for (auto expr : permutationMap.getResults()) { 2632 auto dim = expr.dyn_cast<AffineDimExpr>(); 2633 auto zero = expr.dyn_cast<AffineConstantExpr>(); 2634 if (zero) { 2635 if (zero.getValue() != 0) { 2636 return emitOpError( 2637 "requires a projected permutation_map (at most one dim or the zero " 2638 "constant can appear in each result)"); 2639 } 2640 continue; 2641 } 2642 if (!dim) { 2643 return emitOpError("requires a projected permutation_map (at most one " 2644 "dim or the zero constant can appear in each result)"); 2645 } 2646 if (seen[dim.getPosition()]) { 2647 return emitOpError( 2648 "requires a permutation_map that is a permutation (found one dim " 2649 "used more than once)"); 2650 } 2651 seen[dim.getPosition()] = true; 2652 } 2653 return success(); 2654 } 2655 2656 static LogicalResult 2657 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, 2658 VectorType vectorType, VectorType maskType, 2659 AffineMap permutationMap, ArrayAttr inBounds) { 2660 if (op->hasAttr("masked")) { 2661 return op->emitOpError("masked attribute has been removed. " 2662 "Use in_bounds instead."); 2663 } 2664 2665 if (!shapedType.isa<MemRefType, RankedTensorType>()) 2666 return op->emitOpError( 2667 "requires source to be a memref or ranked tensor type"); 2668 2669 auto elementType = shapedType.getElementType(); 2670 DataLayout dataLayout = DataLayout::closest(op); 2671 if (auto vectorElementType = elementType.dyn_cast<VectorType>()) { 2672 // Memref or tensor has vector element type. 2673 unsigned sourceVecSize = 2674 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * 2675 vectorElementType.getShape().back(); 2676 unsigned resultVecSize = 2677 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * 2678 vectorType.getShape().back(); 2679 if (resultVecSize % sourceVecSize != 0) 2680 return op->emitOpError( 2681 "requires the bitwidth of the minor 1-D vector to be an integral " 2682 "multiple of the bitwidth of the minor 1-D vector of the source"); 2683 2684 unsigned sourceVecEltRank = vectorElementType.getRank(); 2685 unsigned resultVecRank = vectorType.getRank(); 2686 if (sourceVecEltRank > resultVecRank) 2687 return op->emitOpError( 2688 "requires source vector element and vector result ranks to match."); 2689 unsigned rankOffset = resultVecRank - sourceVecEltRank; 2690 // Check that permutation map results match 'rankOffset' of vector type. 2691 if (permutationMap.getNumResults() != rankOffset) 2692 return op->emitOpError("requires a permutation_map with result dims of " 2693 "the same rank as the vector type"); 2694 2695 if (maskType) 2696 return op->emitOpError("does not support masks with vector element type"); 2697 } else { 2698 // Memref or tensor has scalar element type. 2699 unsigned minorSize = 2700 vectorType.getRank() == 0 ? 
1 : vectorType.getShape().back(); 2701 unsigned resultVecSize = 2702 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize; 2703 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0) 2704 return op->emitOpError( 2705 "requires the bitwidth of the minor 1-D vector to be an integral " 2706 "multiple of the bitwidth of the source element type"); 2707 2708 // Check that permutation map results match rank of vector type. 2709 if (permutationMap.getNumResults() != vectorType.getRank()) 2710 return op->emitOpError("requires a permutation_map with result dims of " 2711 "the same rank as the vector type"); 2712 2713 VectorType expectedMaskType = 2714 vector::detail::transferMaskType(vectorType, permutationMap); 2715 if (maskType && expectedMaskType != maskType) 2716 return op->emitOpError("expects mask type consistent with permutation " 2717 "map: ") 2718 << maskType; 2719 } 2720 2721 if (permutationMap.getNumSymbols() != 0) 2722 return op->emitOpError("requires permutation_map without symbols"); 2723 2724 if (permutationMap.getNumInputs() != shapedType.getRank()) 2725 return op->emitOpError("requires a permutation_map with input dims of the " 2726 "same rank as the source type"); 2727 2728 if (inBounds) { 2729 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size())) 2730 return op->emitOpError("expects the optional in_bounds attr of same rank " 2731 "as permutation_map results: ") 2732 << AffineMapAttr::get(permutationMap) 2733 << " vs inBounds of size: " << inBounds.size(); 2734 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) 2735 if (permutationMap.getResult(i).isa<AffineConstantExpr>() && 2736 !inBounds.getValue()[i].cast<BoolAttr>().getValue()) 2737 return op->emitOpError("requires broadcast dimensions to be in-bounds"); 2738 } 2739 2740 return success(); 2741 } 2742 2743 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { 2744 SmallVector<StringRef, 3> elidedAttrs; 2745 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); 2746 if (op.permutation_map().isMinorIdentity()) 2747 elidedAttrs.push_back(op.getPermutationMapAttrName()); 2748 bool elideInBounds = true; 2749 if (auto inBounds = op.in_bounds()) { 2750 for (auto attr : *inBounds) { 2751 if (attr.template cast<BoolAttr>().getValue()) { 2752 elideInBounds = false; 2753 break; 2754 } 2755 } 2756 } 2757 if (elideInBounds) 2758 elidedAttrs.push_back(op.getInBoundsAttrName()); 2759 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); 2760 } 2761 2762 static void print(OpAsmPrinter &p, TransferReadOp op) { 2763 p << " " << op.source() << "[" << op.indices() << "], " << op.padding(); 2764 if (op.mask()) 2765 p << ", " << op.mask(); 2766 printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation())); 2767 p << " : " << op.getShapedType() << ", " << op.getVectorType(); 2768 } 2769 2770 static ParseResult parseTransferReadOp(OpAsmParser &parser, 2771 OperationState &result) { 2772 auto &builder = parser.getBuilder(); 2773 SMLoc typesLoc; 2774 OpAsmParser::OperandType sourceInfo; 2775 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 2776 OpAsmParser::OperandType paddingInfo; 2777 SmallVector<Type, 2> types; 2778 OpAsmParser::OperandType maskInfo; 2779 // Parsing with support for paddingValue. 
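  // The expected textual form is (illustrative):
  //   vector.transfer_read %source[%i0, %i1], %padding, %mask {attr-dict}
  //       : shaped-type, vector-type
  // where the trailing ", %mask" operand is optional.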
2780 if (parser.parseOperand(sourceInfo) || 2781 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 2782 parser.parseComma() || parser.parseOperand(paddingInfo)) 2783 return failure(); 2784 ParseResult hasMask = parser.parseOptionalComma(); 2785 if (hasMask.succeeded()) { 2786 parser.parseOperand(maskInfo); 2787 } 2788 if (parser.parseOptionalAttrDict(result.attributes) || 2789 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 2790 return failure(); 2791 if (types.size() != 2) 2792 return parser.emitError(typesLoc, "requires two types"); 2793 auto indexType = builder.getIndexType(); 2794 auto shapedType = types[0].dyn_cast<ShapedType>(); 2795 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 2796 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 2797 VectorType vectorType = types[1].dyn_cast<VectorType>(); 2798 if (!vectorType) 2799 return parser.emitError(typesLoc, "requires vector type"); 2800 auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); 2801 Attribute mapAttr = result.attributes.get(permutationAttrName); 2802 if (!mapAttr) { 2803 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 2804 // Update `mapAttr` that is used later to determine mask type. 2805 mapAttr = AffineMapAttr::get(permMap); 2806 result.attributes.set(permutationAttrName, mapAttr); 2807 } 2808 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) || 2809 parser.resolveOperands(indexInfo, indexType, result.operands) || 2810 parser.resolveOperand(paddingInfo, shapedType.getElementType(), 2811 result.operands)) 2812 return failure(); 2813 if (hasMask.succeeded()) { 2814 if (shapedType.getElementType().dyn_cast<VectorType>()) 2815 return parser.emitError( 2816 maskInfo.location, "does not support masks with vector element type"); 2817 auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue(); 2818 // Instead of adding the mask type as an op type, compute it based on the 2819 // vector type and the permutation map (to keep the type signature small). 2820 auto maskType = mlir::vector::detail::transferMaskType(vectorType, map); 2821 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 2822 return failure(); 2823 } 2824 result.addAttribute( 2825 TransferReadOp::getOperandSegmentSizeAttr(), 2826 builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1, 2827 static_cast<int32_t>(hasMask.succeeded())})); 2828 return parser.addTypeToList(vectorType, result.types); 2829 } 2830 2831 static LogicalResult verify(TransferReadOp op) { 2832 // Consistency of elemental types in source and vector. 2833 ShapedType shapedType = op.getShapedType(); 2834 VectorType vectorType = op.getVectorType(); 2835 VectorType maskType = op.getMaskType(); 2836 auto paddingType = op.padding().getType(); 2837 auto permutationMap = op.permutation_map(); 2838 auto sourceElementType = shapedType.getElementType(); 2839 2840 if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank()) 2841 return op.emitOpError("requires ") << shapedType.getRank() << " indices"; 2842 2843 if (failed( 2844 verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()), 2845 shapedType, vectorType, maskType, permutationMap, 2846 op.in_bounds() ? *op.in_bounds() : ArrayAttr()))) 2847 return failure(); 2848 2849 if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) { 2850 // Source has vector element type. 2851 // Check that 'sourceVectorElementType' and 'paddingType' types match. 
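    // For example (illustrative), reading from a memref<?xvector<4xf32>>
    // requires a padding value of type vector<4xf32>.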
2852 if (sourceVectorElementType != paddingType) 2853 return op.emitOpError( 2854 "requires source element type and padding type to match."); 2855 2856 } else { 2857 // Check that 'paddingType' is valid to store in a vector type. 2858 if (!VectorType::isValidElementType(paddingType)) 2859 return op.emitOpError("requires valid padding vector elemental type"); 2860 2861 // Check that padding type and vector element types match. 2862 if (paddingType != sourceElementType) 2863 return op.emitOpError( 2864 "requires formal padding and source of the same elemental type"); 2865 } 2866 2867 return verifyPermutationMap(permutationMap, 2868 [&op](Twine t) { return op.emitOpError(t); }); 2869 } 2870 2871 /// This is a common class used for patterns of the form 2872 /// ``` 2873 /// someop(memrefcast) -> someop 2874 /// ``` 2875 /// It folds the source of the memref.cast into the root operation directly. 2876 static LogicalResult foldMemRefCast(Operation *op) { 2877 bool folded = false; 2878 for (OpOperand &operand : op->getOpOperands()) { 2879 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 2880 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 2881 operand.set(castOp.getOperand()); 2882 folded = true; 2883 } 2884 } 2885 return success(folded); 2886 } 2887 2888 static LogicalResult foldTensorCast(Operation *op) { 2889 bool folded = false; 2890 for (OpOperand &operand : op->getOpOperands()) { 2891 auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); 2892 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { 2893 operand.set(castOp.getOperand()); 2894 folded = true; 2895 } 2896 } 2897 return success(folded); 2898 } 2899 2900 template <typename TransferOp> 2901 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { 2902 // TODO: support more aggressive createOrFold on: 2903 // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)` 2904 if (op.getShapedType().isDynamicDim(indicesIdx)) 2905 return false; 2906 Value index = op.indices()[indicesIdx]; 2907 auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>(); 2908 if (!cstOp) 2909 return false; 2910 2911 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); 2912 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); 2913 2914 return cstOp.value() + vectorSize <= sourceSize; 2915 } 2916 2917 template <typename TransferOp> 2918 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { 2919 // TODO: support 0-d corner case. 2920 // TODO: Be less conservative. 2921 if (op.getTransferRank() == 0) 2922 return failure(); 2923 AffineMap permutationMap = op.permutation_map(); 2924 bool changed = false; 2925 SmallVector<bool, 4> newInBounds; 2926 newInBounds.reserve(op.getTransferRank()); 2927 for (unsigned i = 0; i < op.getTransferRank(); ++i) { 2928 // Already marked as in-bounds, nothing to see here. 2929 if (op.isDimInBounds(i)) { 2930 newInBounds.push_back(true); 2931 continue; 2932 } 2933 // Currently out-of-bounds, check whether we can statically determine it is 2934 // inBounds. 2935 auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>(); 2936 assert(dimExpr && "Broadcast dims must be in-bounds"); 2937 auto inBounds = 2938 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition()); 2939 newInBounds.push_back(inBounds); 2940 // We commit the pattern if it is "more inbounds". 2941 changed |= inBounds; 2942 } 2943 if (!changed) 2944 return failure(); 2945 // OpBuilder is only used as a helper to build an I64ArrayAttr. 
2946 OpBuilder b(op.getContext()); 2947 op->setAttr(TransferOp::getInBoundsAttrName(), 2948 b.getBoolArrayAttr(newInBounds)); 2949 return success(); 2950 } 2951 2952 /// ``` 2953 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} 2954 /// : vector<1x4xf32>, tensor<4x4xf32> 2955 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} 2956 /// : tensor<4x4xf32>, vector<1x4xf32> 2957 /// ``` 2958 /// -> Folds into 2959 /// ``` 2960 /// %v0 2961 /// ``` 2962 static Value foldRAW(TransferReadOp readOp) { 2963 if (!readOp.getShapedType().isa<RankedTensorType>()) 2964 return {}; 2965 auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>(); 2966 while (defWrite) { 2967 if (checkSameValueRAW(defWrite, readOp)) 2968 return defWrite.vector(); 2969 if (!isDisjointTransferIndices( 2970 cast<VectorTransferOpInterface>(defWrite.getOperation()), 2971 cast<VectorTransferOpInterface>(readOp.getOperation()))) 2972 break; 2973 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>(); 2974 } 2975 return {}; 2976 } 2977 2978 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) { 2979 if (Value vec = foldRAW(*this)) 2980 return vec; 2981 /// transfer_read(memrefcast) -> transfer_read 2982 if (succeeded(foldTransferInBoundsAttribute(*this))) 2983 return getResult(); 2984 if (succeeded(foldMemRefCast(*this))) 2985 return getResult(); 2986 if (succeeded(foldTensorCast(*this))) 2987 return getResult(); 2988 return OpFoldResult(); 2989 } 2990 2991 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { 2992 return llvm::to_vector<4>(getVectorType().getShape()); 2993 } 2994 2995 void TransferReadOp::getEffects( 2996 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 2997 &effects) { 2998 if (getShapedType().isa<MemRefType>()) 2999 effects.emplace_back(MemoryEffects::Read::get(), source(), 3000 SideEffects::DefaultResource::get()); 3001 } 3002 3003 namespace { 3004 /// Fold transfer_reads of a tensor.extract_slice op. E.g.: 3005 /// 3006 /// ``` 3007 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1] 3008 /// : tensor<?x?xf32> to tensor<?x?xf32> 3009 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]} 3010 /// : tensor<?x?xf32>, vector<4x5xf32> 3011 /// ``` 3012 /// is rewritten to: 3013 /// ``` 3014 /// %p0 = arith.addi %a, %e : index 3015 /// %p1 = arith.addi %b, %f : index 3016 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} 3017 /// : tensor<?x?xf32>, vector<4x5xf32> 3018 /// ``` 3019 struct FoldExtractSliceIntoTransferRead 3020 : public OpRewritePattern<TransferReadOp> { 3021 public: 3022 using OpRewritePattern<TransferReadOp>::OpRewritePattern; 3023 3024 LogicalResult matchAndRewrite(TransferReadOp xferOp, 3025 PatternRewriter &rewriter) const override { 3026 // TODO: support 0-d corner case. 3027 if (xferOp.getTransferRank() == 0) 3028 return failure(); 3029 if (xferOp.hasOutOfBoundsDim()) 3030 return failure(); 3031 if (!xferOp.permutation_map().isIdentity()) 3032 return failure(); 3033 if (xferOp.mask()) 3034 return failure(); 3035 auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>(); 3036 if (!extractOp) 3037 return failure(); 3038 if (!extractOp.hasUnitStride()) 3039 return failure(); 3040 3041 // Bail on illegal rank-reduction: we need to check that the rank-reduced 3042 // dims are exactly the leading dims. I.e. 
the following is illegal: 3043 // ``` 3044 // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] : 3045 // tensor<2x1x4xf32> to tensor<2x4xf32> 3046 // %1 = vector.transfer_read %0[0,0], %cst : 3047 // tensor<2x4xf32>, vector<2x4xf32> 3048 // ``` 3049 // 3050 // Cannot fold into: 3051 // ``` 3052 // %0 = vector.transfer_read %t[0,0,0], %cst : 3053 // tensor<2x1x4xf32>, vector<2x4xf32> 3054 // ``` 3055 // For this, check the trailing `vectorRank` dims of the extract_slice 3056 // result tensor match the trailing dims of the inferred result tensor. 3057 int64_t rankReduced = 3058 extractOp.getSourceType().getRank() - extractOp.getType().getRank(); 3059 int64_t vectorRank = xferOp.getVectorType().getRank(); 3060 RankedTensorType inferredDestTensorType = 3061 tensor::ExtractSliceOp::inferResultType( 3062 extractOp.getSourceType(), extractOp.getMixedOffsets(), 3063 extractOp.getMixedSizes(), extractOp.getMixedStrides()); 3064 auto actualDestTensorShape = extractOp.getType().getShape(); 3065 if (rankReduced > 0 && 3066 actualDestTensorShape.take_back(vectorRank) != 3067 inferredDestTensorType.getShape().take_back(vectorRank)) 3068 return failure(); 3069 3070 SmallVector<Value> newIndices; 3071 // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced 3072 // indices first. 3073 for (int64_t i = 0; i < rankReduced; ++i) { 3074 OpFoldResult offset = extractOp.getMixedOffsets()[i]; 3075 newIndices.push_back(getValueOrCreateConstantIndexOp( 3076 rewriter, extractOp.getLoc(), offset)); 3077 } 3078 for (const auto &it : llvm::enumerate(xferOp.indices())) { 3079 OpFoldResult offset = 3080 extractOp.getMixedOffsets()[it.index() + rankReduced]; 3081 newIndices.push_back(rewriter.create<arith::AddIOp>( 3082 xferOp->getLoc(), it.value(), 3083 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), 3084 offset))); 3085 } 3086 SmallVector<bool> inBounds(xferOp.getTransferRank(), true); 3087 rewriter.replaceOpWithNewOp<TransferReadOp>( 3088 xferOp, xferOp.getVectorType(), extractOp.source(), newIndices, 3089 xferOp.padding(), ArrayRef<bool>{inBounds}); 3090 3091 return success(); 3092 } 3093 }; 3094 } // namespace 3095 3096 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3097 MLIRContext *context) { 3098 results.add<FoldExtractSliceIntoTransferRead>(context); 3099 } 3100 3101 //===----------------------------------------------------------------------===// 3102 // TransferWriteOp 3103 //===----------------------------------------------------------------------===// 3104 3105 /// 1. Builder with type inference. 3106 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3107 Value vector, Value dest, ValueRange indices, 3108 AffineMapAttr permutationMapAttr, 3109 /*optional*/ Value mask, 3110 /*optional*/ ArrayAttr inBoundsAttr) { 3111 Type resultType = dest.getType().dyn_cast<RankedTensorType>(); 3112 build(builder, result, resultType, vector, dest, indices, permutationMapAttr, 3113 mask, inBoundsAttr); 3114 } 3115 3116 /// 2. Builder with type inference that sets an empty mask (variant with attrs). 3117 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3118 Value vector, Value dest, ValueRange indices, 3119 AffineMapAttr permutationMapAttr, 3120 /*optional*/ ArrayAttr inBoundsAttr) { 3121 build(builder, result, vector, dest, indices, permutationMapAttr, 3122 /*mask=*/Value(), inBoundsAttr); 3123 } 3124 3125 /// 3. 
Builder with type inference that sets an empty mask (variant without 3126 /// attrs) 3127 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3128 Value vector, Value dest, ValueRange indices, 3129 AffineMap permutationMap, 3130 Optional<ArrayRef<bool>> inBounds) { 3131 auto permutationMapAttr = AffineMapAttr::get(permutationMap); 3132 auto inBoundsAttr = (inBounds && !inBounds.getValue().empty()) 3133 ? builder.getBoolArrayAttr(inBounds.getValue()) 3134 : ArrayAttr(); 3135 build(builder, result, vector, dest, indices, permutationMapAttr, 3136 /*mask=*/Value(), inBoundsAttr); 3137 } 3138 3139 /// 4. Builder with type inference that sets an empty mask and sets permutation 3140 /// map to 'getMinorIdentityMap'. 3141 void TransferWriteOp::build(OpBuilder &builder, OperationState &result, 3142 Value vector, Value dest, ValueRange indices, 3143 Optional<ArrayRef<bool>> inBounds) { 3144 auto vectorType = vector.getType().cast<VectorType>(); 3145 AffineMap permutationMap = getTransferMinorIdentityMap( 3146 dest.getType().cast<ShapedType>(), vectorType); 3147 build(builder, result, vector, dest, indices, permutationMap, inBounds); 3148 } 3149 3150 static ParseResult parseTransferWriteOp(OpAsmParser &parser, 3151 OperationState &result) { 3152 auto &builder = parser.getBuilder(); 3153 SMLoc typesLoc; 3154 OpAsmParser::OperandType vectorInfo, sourceInfo; 3155 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 3156 SmallVector<Type, 2> types; 3157 OpAsmParser::OperandType maskInfo; 3158 if (parser.parseOperand(vectorInfo) || parser.parseComma() || 3159 parser.parseOperand(sourceInfo) || 3160 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) 3161 return failure(); 3162 ParseResult hasMask = parser.parseOptionalComma(); 3163 if (hasMask.succeeded() && parser.parseOperand(maskInfo)) 3164 return failure(); 3165 if (parser.parseOptionalAttrDict(result.attributes) || 3166 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) 3167 return failure(); 3168 if (types.size() != 2) 3169 return parser.emitError(typesLoc, "requires two types"); 3170 auto indexType = builder.getIndexType(); 3171 VectorType vectorType = types[0].dyn_cast<VectorType>(); 3172 if (!vectorType) 3173 return parser.emitError(typesLoc, "requires vector type"); 3174 ShapedType shapedType = types[1].dyn_cast<ShapedType>(); 3175 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>()) 3176 return parser.emitError(typesLoc, "requires memref or ranked tensor type"); 3177 auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); 3178 auto attr = result.attributes.get(permutationAttrName); 3179 if (!attr) { 3180 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); 3181 result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap)); 3182 } 3183 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) || 3184 parser.resolveOperand(sourceInfo, shapedType, result.operands) || 3185 parser.resolveOperands(indexInfo, indexType, result.operands)) 3186 return failure(); 3187 if (hasMask.succeeded()) { 3188 if (shapedType.getElementType().dyn_cast<VectorType>()) 3189 return parser.emitError( 3190 maskInfo.location, "does not support masks with vector element type"); 3191 auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type()); 3192 if (parser.resolveOperand(maskInfo, maskType, result.operands)) 3193 return failure(); 3194 } 3195 result.addAttribute( 3196 TransferWriteOp::getOperandSegmentSizeAttr(), 3197 
builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()), 3198 static_cast<int32_t>(hasMask.succeeded())})); 3199 return failure(shapedType.isa<RankedTensorType>() && 3200 parser.addTypeToList(shapedType, result.types)); 3201 } 3202 3203 static void print(OpAsmPrinter &p, TransferWriteOp op) { 3204 p << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]"; 3205 if (op.mask()) 3206 p << ", " << op.mask(); 3207 printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation())); 3208 p << " : " << op.getVectorType() << ", " << op.getShapedType(); 3209 } 3210 3211 static LogicalResult verify(TransferWriteOp op) { 3212 // Consistency of elemental types in shape and vector. 3213 ShapedType shapedType = op.getShapedType(); 3214 VectorType vectorType = op.getVectorType(); 3215 VectorType maskType = op.getMaskType(); 3216 auto permutationMap = op.permutation_map(); 3217 3218 if (llvm::size(op.indices()) != shapedType.getRank()) 3219 return op.emitOpError("requires ") << shapedType.getRank() << " indices"; 3220 3221 // We do not allow broadcast dimensions on TransferWriteOps for the moment, 3222 // as the semantics is unclear. This can be revisited later if necessary. 3223 if (op.hasBroadcastDim()) 3224 return op.emitOpError("should not have broadcast dimensions"); 3225 3226 if (failed( 3227 verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()), 3228 shapedType, vectorType, maskType, permutationMap, 3229 op.in_bounds() ? *op.in_bounds() : ArrayAttr()))) 3230 return failure(); 3231 3232 return verifyPermutationMap(permutationMap, 3233 [&op](Twine t) { return op.emitOpError(t); }); 3234 } 3235 3236 /// Fold: 3237 /// ``` 3238 /// %t1 = ... 3239 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} : 3240 /// tensor<static_sizesxf32>, vector<static_sizesxf32> 3241 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} : 3242 /// vector<static_sizesxf32>, tensor<static_sizesxf32> 3243 /// ``` 3244 /// 3245 /// into: 3246 /// 3247 /// ``` 3248 /// %t0 3249 /// ``` 3250 /// 3251 /// The producer of t1 may or may not be DCE'd depending on whether it is a 3252 /// block argument or has side effects. 3253 static LogicalResult foldReadInitWrite(TransferWriteOp write, 3254 ArrayRef<Attribute>, 3255 SmallVectorImpl<OpFoldResult> &results) { 3256 // TODO: support 0-d corner case. 3257 if (write.getTransferRank() == 0) 3258 return failure(); 3259 auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>(); 3260 // If not operating on tensors, bail. 3261 if (!rankedTensorType) 3262 return failure(); 3263 // If no read, bail. 3264 auto read = write.vector().getDefiningOp<vector::TransferReadOp>(); 3265 if (!read) 3266 return failure(); 3267 // TODO: support 0-d corner case. 3268 if (read.getTransferRank() == 0) 3269 return failure(); 3270 // For now, only accept minor identity. Future: composition is minor identity. 3271 if (!read.permutation_map().isMinorIdentity() || 3272 !write.permutation_map().isMinorIdentity()) 3273 return failure(); 3274 // Bail on mismatching ranks. 3275 if (read.getTransferRank() != write.getTransferRank()) 3276 return failure(); 3277 // Bail on potential out-of-bounds accesses. 3278 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim()) 3279 return failure(); 3280 // Tensor types must be the same. 3281 if (read.source().getType() != rankedTensorType) 3282 return failure(); 3283 // Vector types must be the same. 
3284 if (read.getVectorType() != write.getVectorType())
3285 return failure();
3286 // Vector and Tensor shapes must match.
3287 if (read.getVectorType().getShape() != rankedTensorType.getShape())
3288 return failure();
3289 // If any index is nonzero.
3290 auto isNotConstantZero = [](Value v) {
3291 auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3292 return !cstOp || cstOp.value() != 0;
3293 };
3294 if (llvm::any_of(read.indices(), isNotConstantZero) ||
3295 llvm::any_of(write.indices(), isNotConstantZero))
3296 return failure();
3297 // Success.
3298 results.push_back(read.source());
3299 return success();
3300 }
3301
3302 static bool checkSameValueWAR(vector::TransferReadOp read,
3303 vector::TransferWriteOp write) {
3304 return read.source() == write.source() && read.indices() == write.indices() &&
3305 read.permutation_map() == write.permutation_map() &&
3306 read.getVectorType() == write.getVectorType() && !read.mask() &&
3307 !write.mask();
3308 }
3309 /// Fold transfer_write write after read:
3310 /// ```
3311 /// %t0 = ...
3312 /// %v = vector.transfer_read %t0[%c0...] :
3313 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3314 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
3315 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3316 /// ```
3317 ///
3318 /// into:
3319 ///
3320 /// ```
3321 /// %t0
3322 /// ```
3323 static LogicalResult foldWAR(TransferWriteOp write,
3324 SmallVectorImpl<OpFoldResult> &results) {
3325 if (!write.source().getType().isa<RankedTensorType>())
3326 return failure();
3327 auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3328 if (!read)
3329 return failure();
3330
3331 if (!checkSameValueWAR(read, write))
3332 return failure();
3333 results.push_back(read.source());
3334 return success();
3335 }
3336
3337 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3338 SmallVectorImpl<OpFoldResult> &results) {
3339 if (succeeded(foldReadInitWrite(*this, operands, results)))
3340 return success();
3341 if (succeeded(foldWAR(*this, results)))
3342 return success();
3343 if (succeeded(foldTransferInBoundsAttribute(*this)))
3344 return success();
3345 return foldMemRefCast(*this);
3346 }
3347
3348 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3349 return llvm::to_vector<4>(getVectorType().getShape());
3350 }
3351
3352 void TransferWriteOp::getEffects(
3353 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3354 &effects) {
3355 if (getShapedType().isa<MemRefType>())
3356 effects.emplace_back(MemoryEffects::Write::get(), source(),
3357 SideEffects::DefaultResource::get());
3358 }
3359
3360 namespace {
3361 /// Remove a dead transfer write from the SSA chain so that it can be
3362 /// eliminated by DCE.
3363 /// ```
3364 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3365 /// : vector<1x4xf32>, tensor<4x4xf32>
3366 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3367 /// : vector<1x4xf32>, tensor<4x4xf32>
3368 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3369 /// : vector<1x4xf32>, tensor<4x4xf32>
3370 /// ```
3371 ///
3372 /// into:
3373 ///
3374 /// ```
3375 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3376 /// : vector<1x4xf32>, tensor<4x4xf32>
3377 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3378 /// : vector<1x4xf32>, tensor<4x4xf32>
3379 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3380 /// : vector<1x4xf32>, tensor<4x4xf32>
3381 /// ```
3382 ///
3383 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3384 /// any other uses.
3385 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3386 public:
3387 using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
3388 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3389 PatternRewriter &rewriter) const override {
3390 if (!writeOp.getShapedType().isa<RankedTensorType>())
3391 return failure();
3392 vector::TransferWriteOp writeToModify = writeOp;
3393
3394 auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
3395 while (defWrite) {
3396 if (checkSameValueWAW(writeOp, defWrite)) {
3397 writeToModify.sourceMutable().assign(defWrite.source());
3398 return success();
3399 }
3400 if (!isDisjointTransferIndices(
3401 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3402 cast<VectorTransferOpInterface>(writeOp.getOperation())))
3403 break;
3404 // If the previous write op doesn't have any other use we can safely look
3405 // at the previous store to see if it can be removed.
3406 if (!defWrite->hasOneUse())
3407 break;
3408 writeToModify = defWrite;
3409 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
3410 }
3411 return failure();
3412 }
3413 };
3414
3415 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3416 /// could directly write to the insert_slice's destination. E.g.:
3417 ///
3418 /// ```
3419 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3420 /// : vector<4x5xf32>, tensor<4x5xf32>
3421 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3422 /// : tensor<4x5xf32> into tensor<?x?xf32>
3423 /// ```
3424 /// is rewritten to:
3425 /// ```
3426 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3427 /// : vector<4x5xf32>, tensor<?x?xf32>
3428 /// ```
3429 struct FoldInsertSliceIntoTransferWrite
3430 : public OpRewritePattern<tensor::InsertSliceOp> {
3431 public:
3432 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3433
3434 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3435 PatternRewriter &rewriter) const override {
3436 if (!insertOp.hasUnitStride())
3437 return failure();
3438
3439 auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
3440 if (!xferOp)
3441 return failure();
3442 // TODO: support 0-d corner case.
3443 if (xferOp.getTransferRank() == 0)
3444 return failure();
3445
3446 if (xferOp.hasOutOfBoundsDim())
3447 return failure();
3448 if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3449 return failure();
3450 if (xferOp.mask())
3451 return failure();
3452 // Fold only if the TransferWriteOp completely overwrites the `source` with
3453 // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3454 // content is the data of the vector.
3455 if (!llvm::equal(xferOp.getVectorType().getShape(),
3456 xferOp.getShapedType().getShape()))
3457 return failure();
3458 if (!xferOp.permutation_map().isIdentity())
3459 return failure();
3460
3461 // Bail on illegal rank-reduction: we need to check that the rank-reduced
3462 // dims are exactly the leading dims. I.e.
the following is illegal: 3463 // ``` 3464 // %0 = vector.transfer_write %v, %t[0,0], %cst : 3465 // vector<2x4xf32>, tensor<2x4xf32> 3466 // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] : 3467 // tensor<2x4xf32> into tensor<2x1x4xf32> 3468 // ``` 3469 // 3470 // Cannot fold into: 3471 // ``` 3472 // %0 = vector.transfer_write %v, %t[0,0,0], %cst : 3473 // vector<2x4xf32>, tensor<2x1x4xf32> 3474 // ``` 3475 // For this, check the trailing `vectorRank` dims of the insert_slice result 3476 // tensor match the trailing dims of the inferred result tensor. 3477 int64_t rankReduced = 3478 insertOp.getType().getRank() - insertOp.getSourceType().getRank(); 3479 int64_t vectorRank = xferOp.getVectorType().getRank(); 3480 RankedTensorType inferredSourceTensorType = 3481 tensor::ExtractSliceOp::inferResultType( 3482 insertOp.getType(), insertOp.getMixedOffsets(), 3483 insertOp.getMixedSizes(), insertOp.getMixedStrides()); 3484 auto actualSourceTensorShape = insertOp.getSourceType().getShape(); 3485 if (rankReduced > 0 && 3486 actualSourceTensorShape.take_back(vectorRank) != 3487 inferredSourceTensorType.getShape().take_back(vectorRank)) 3488 return failure(); 3489 3490 SmallVector<Value> indices = getValueOrCreateConstantIndexOp( 3491 rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); 3492 SmallVector<bool> inBounds(xferOp.getTransferRank(), true); 3493 rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(), 3494 insertOp.dest(), indices, 3495 ArrayRef<bool>{inBounds}); 3496 return success(); 3497 } 3498 }; 3499 } // namespace 3500 3501 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, 3502 MLIRContext *context) { 3503 results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context); 3504 } 3505 3506 //===----------------------------------------------------------------------===// 3507 // LoadOp 3508 //===----------------------------------------------------------------------===// 3509 3510 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, 3511 MemRefType memRefTy) { 3512 if (!isLastMemrefDimUnitStride(memRefTy)) 3513 return op->emitOpError("most minor memref dim must have unit stride"); 3514 return success(); 3515 } 3516 3517 static LogicalResult verify(vector::LoadOp op) { 3518 VectorType resVecTy = op.getVectorType(); 3519 MemRefType memRefTy = op.getMemRefType(); 3520 3521 if (failed(verifyLoadStoreMemRefLayout(op, memRefTy))) 3522 return failure(); 3523 3524 // Checks for vector memrefs. 
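// E.g. (illustrative): a load from memref<8xvector<4xf32>> must produce
// exactly vector<4xf32>; the element-type comparison below then reduces to
// f32 vs. f32.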
3525 Type memElemTy = memRefTy.getElementType(); 3526 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3527 if (memVecTy != resVecTy) 3528 return op.emitOpError("base memref and result vector types should match"); 3529 memElemTy = memVecTy.getElementType(); 3530 } 3531 3532 if (resVecTy.getElementType() != memElemTy) 3533 return op.emitOpError("base and result element types should match"); 3534 if (llvm::size(op.indices()) != memRefTy.getRank()) 3535 return op.emitOpError("requires ") << memRefTy.getRank() << " indices"; 3536 return success(); 3537 } 3538 3539 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) { 3540 if (succeeded(foldMemRefCast(*this))) 3541 return getResult(); 3542 return OpFoldResult(); 3543 } 3544 3545 //===----------------------------------------------------------------------===// 3546 // StoreOp 3547 //===----------------------------------------------------------------------===// 3548 3549 static LogicalResult verify(vector::StoreOp op) { 3550 VectorType valueVecTy = op.getVectorType(); 3551 MemRefType memRefTy = op.getMemRefType(); 3552 3553 if (failed(verifyLoadStoreMemRefLayout(op, memRefTy))) 3554 return failure(); 3555 3556 // Checks for vector memrefs. 3557 Type memElemTy = memRefTy.getElementType(); 3558 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { 3559 if (memVecTy != valueVecTy) 3560 return op.emitOpError( 3561 "base memref and valueToStore vector types should match"); 3562 memElemTy = memVecTy.getElementType(); 3563 } 3564 3565 if (valueVecTy.getElementType() != memElemTy) 3566 return op.emitOpError("base and valueToStore element type should match"); 3567 if (llvm::size(op.indices()) != memRefTy.getRank()) 3568 return op.emitOpError("requires ") << memRefTy.getRank() << " indices"; 3569 return success(); 3570 } 3571 3572 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands, 3573 SmallVectorImpl<OpFoldResult> &results) { 3574 return foldMemRefCast(*this); 3575 } 3576 3577 //===----------------------------------------------------------------------===// 3578 // MaskedLoadOp 3579 //===----------------------------------------------------------------------===// 3580 3581 static LogicalResult verify(MaskedLoadOp op) { 3582 VectorType maskVType = op.getMaskVectorType(); 3583 VectorType passVType = op.getPassThruVectorType(); 3584 VectorType resVType = op.getVectorType(); 3585 MemRefType memType = op.getMemRefType(); 3586 3587 if (resVType.getElementType() != memType.getElementType()) 3588 return op.emitOpError("base and result element type should match"); 3589 if (llvm::size(op.indices()) != memType.getRank()) 3590 return op.emitOpError("requires ") << memType.getRank() << " indices"; 3591 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3592 return op.emitOpError("expected result dim to match mask dim"); 3593 if (resVType != passVType) 3594 return op.emitOpError("expected pass_thru of same type as result type"); 3595 return success(); 3596 } 3597 3598 namespace { 3599 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { 3600 public: 3601 using OpRewritePattern<MaskedLoadOp>::OpRewritePattern; 3602 LogicalResult matchAndRewrite(MaskedLoadOp load, 3603 PatternRewriter &rewriter) const override { 3604 switch (get1DMaskFormat(load.mask())) { 3605 case MaskFormat::AllTrue: 3606 rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(), 3607 load.base(), load.indices()); 3608 return success(); 3609 case MaskFormat::AllFalse: 3610 rewriter.replaceOp(load, load.pass_thru()); 3611 return success(); 3612 case 
MaskFormat::Unknown: 3613 return failure(); 3614 } 3615 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad"); 3616 } 3617 }; 3618 } // namespace 3619 3620 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3621 MLIRContext *context) { 3622 results.add<MaskedLoadFolder>(context); 3623 } 3624 3625 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) { 3626 if (succeeded(foldMemRefCast(*this))) 3627 return getResult(); 3628 return OpFoldResult(); 3629 } 3630 3631 //===----------------------------------------------------------------------===// 3632 // MaskedStoreOp 3633 //===----------------------------------------------------------------------===// 3634 3635 static LogicalResult verify(MaskedStoreOp op) { 3636 VectorType maskVType = op.getMaskVectorType(); 3637 VectorType valueVType = op.getVectorType(); 3638 MemRefType memType = op.getMemRefType(); 3639 3640 if (valueVType.getElementType() != memType.getElementType()) 3641 return op.emitOpError("base and valueToStore element type should match"); 3642 if (llvm::size(op.indices()) != memType.getRank()) 3643 return op.emitOpError("requires ") << memType.getRank() << " indices"; 3644 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3645 return op.emitOpError("expected valueToStore dim to match mask dim"); 3646 return success(); 3647 } 3648 3649 namespace { 3650 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { 3651 public: 3652 using OpRewritePattern<MaskedStoreOp>::OpRewritePattern; 3653 LogicalResult matchAndRewrite(MaskedStoreOp store, 3654 PatternRewriter &rewriter) const override { 3655 switch (get1DMaskFormat(store.mask())) { 3656 case MaskFormat::AllTrue: 3657 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3658 store, store.valueToStore(), store.base(), store.indices()); 3659 return success(); 3660 case MaskFormat::AllFalse: 3661 rewriter.eraseOp(store); 3662 return success(); 3663 case MaskFormat::Unknown: 3664 return failure(); 3665 } 3666 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore"); 3667 } 3668 }; 3669 } // namespace 3670 3671 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3672 MLIRContext *context) { 3673 results.add<MaskedStoreFolder>(context); 3674 } 3675 3676 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands, 3677 SmallVectorImpl<OpFoldResult> &results) { 3678 return foldMemRefCast(*this); 3679 } 3680 3681 //===----------------------------------------------------------------------===// 3682 // GatherOp 3683 //===----------------------------------------------------------------------===// 3684 3685 static LogicalResult verify(GatherOp op) { 3686 VectorType indVType = op.getIndexVectorType(); 3687 VectorType maskVType = op.getMaskVectorType(); 3688 VectorType resVType = op.getVectorType(); 3689 MemRefType memType = op.getMemRefType(); 3690 3691 if (resVType.getElementType() != memType.getElementType()) 3692 return op.emitOpError("base and result element type should match"); 3693 if (llvm::size(op.indices()) != memType.getRank()) 3694 return op.emitOpError("requires ") << memType.getRank() << " indices"; 3695 if (resVType.getDimSize(0) != indVType.getDimSize(0)) 3696 return op.emitOpError("expected result dim to match indices dim"); 3697 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3698 return op.emitOpError("expected result dim to match mask dim"); 3699 if (resVType != op.getPassThruVectorType()) 3700 return op.emitOpError("expected pass_thru of same type as result type"); 3701 return success(); 3702 } 3703 
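// The folder below canonicalizes gathers whose mask is statically known: an
// all-false mask makes the gather equivalent to its pass_thru operand, while
// an all-true mask has no unmasked gather counterpart and is left untouched.
// Sketch (illustrative; the exact textual form is abbreviated):
//
//   %g = vector.gather %base[%c0] [%index_vec], %all_false, %pass_thru
//       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
//         into vector<16xf32>
//   // ---> folds to %pass_thru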
3704 namespace { 3705 class GatherFolder final : public OpRewritePattern<GatherOp> { 3706 public: 3707 using OpRewritePattern<GatherOp>::OpRewritePattern; 3708 LogicalResult matchAndRewrite(GatherOp gather, 3709 PatternRewriter &rewriter) const override { 3710 switch (get1DMaskFormat(gather.mask())) { 3711 case MaskFormat::AllTrue: 3712 return failure(); // no unmasked equivalent 3713 case MaskFormat::AllFalse: 3714 rewriter.replaceOp(gather, gather.pass_thru()); 3715 return success(); 3716 case MaskFormat::Unknown: 3717 return failure(); 3718 } 3719 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder"); 3720 } 3721 }; 3722 } // namespace 3723 3724 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, 3725 MLIRContext *context) { 3726 results.add<GatherFolder>(context); 3727 } 3728 3729 //===----------------------------------------------------------------------===// 3730 // ScatterOp 3731 //===----------------------------------------------------------------------===// 3732 3733 static LogicalResult verify(ScatterOp op) { 3734 VectorType indVType = op.getIndexVectorType(); 3735 VectorType maskVType = op.getMaskVectorType(); 3736 VectorType valueVType = op.getVectorType(); 3737 MemRefType memType = op.getMemRefType(); 3738 3739 if (valueVType.getElementType() != memType.getElementType()) 3740 return op.emitOpError("base and valueToStore element type should match"); 3741 if (llvm::size(op.indices()) != memType.getRank()) 3742 return op.emitOpError("requires ") << memType.getRank() << " indices"; 3743 if (valueVType.getDimSize(0) != indVType.getDimSize(0)) 3744 return op.emitOpError("expected valueToStore dim to match indices dim"); 3745 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3746 return op.emitOpError("expected valueToStore dim to match mask dim"); 3747 return success(); 3748 } 3749 3750 namespace { 3751 class ScatterFolder final : public OpRewritePattern<ScatterOp> { 3752 public: 3753 using OpRewritePattern<ScatterOp>::OpRewritePattern; 3754 LogicalResult matchAndRewrite(ScatterOp scatter, 3755 PatternRewriter &rewriter) const override { 3756 switch (get1DMaskFormat(scatter.mask())) { 3757 case MaskFormat::AllTrue: 3758 return failure(); // no unmasked equivalent 3759 case MaskFormat::AllFalse: 3760 rewriter.eraseOp(scatter); 3761 return success(); 3762 case MaskFormat::Unknown: 3763 return failure(); 3764 } 3765 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder"); 3766 } 3767 }; 3768 } // namespace 3769 3770 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results, 3771 MLIRContext *context) { 3772 results.add<ScatterFolder>(context); 3773 } 3774 3775 //===----------------------------------------------------------------------===// 3776 // ExpandLoadOp 3777 //===----------------------------------------------------------------------===// 3778 3779 static LogicalResult verify(ExpandLoadOp op) { 3780 VectorType maskVType = op.getMaskVectorType(); 3781 VectorType passVType = op.getPassThruVectorType(); 3782 VectorType resVType = op.getVectorType(); 3783 MemRefType memType = op.getMemRefType(); 3784 3785 if (resVType.getElementType() != memType.getElementType()) 3786 return op.emitOpError("base and result element type should match"); 3787 if (llvm::size(op.indices()) != memType.getRank()) 3788 return op.emitOpError("requires ") << memType.getRank() << " indices"; 3789 if (resVType.getDimSize(0) != maskVType.getDimSize(0)) 3790 return op.emitOpError("expected result dim to match mask dim"); 3791 if (resVType != passVType) 3792 return 
op.emitOpError("expected pass_thru of same type as result type"); 3793 return success(); 3794 } 3795 3796 namespace { 3797 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { 3798 public: 3799 using OpRewritePattern<ExpandLoadOp>::OpRewritePattern; 3800 LogicalResult matchAndRewrite(ExpandLoadOp expand, 3801 PatternRewriter &rewriter) const override { 3802 switch (get1DMaskFormat(expand.mask())) { 3803 case MaskFormat::AllTrue: 3804 rewriter.replaceOpWithNewOp<vector::LoadOp>( 3805 expand, expand.getType(), expand.base(), expand.indices()); 3806 return success(); 3807 case MaskFormat::AllFalse: 3808 rewriter.replaceOp(expand, expand.pass_thru()); 3809 return success(); 3810 case MaskFormat::Unknown: 3811 return failure(); 3812 } 3813 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder"); 3814 } 3815 }; 3816 } // namespace 3817 3818 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 3819 MLIRContext *context) { 3820 results.add<ExpandLoadFolder>(context); 3821 } 3822 3823 //===----------------------------------------------------------------------===// 3824 // CompressStoreOp 3825 //===----------------------------------------------------------------------===// 3826 3827 static LogicalResult verify(CompressStoreOp op) { 3828 VectorType maskVType = op.getMaskVectorType(); 3829 VectorType valueVType = op.getVectorType(); 3830 MemRefType memType = op.getMemRefType(); 3831 3832 if (valueVType.getElementType() != memType.getElementType()) 3833 return op.emitOpError("base and valueToStore element type should match"); 3834 if (llvm::size(op.indices()) != memType.getRank()) 3835 return op.emitOpError("requires ") << memType.getRank() << " indices"; 3836 if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) 3837 return op.emitOpError("expected valueToStore dim to match mask dim"); 3838 return success(); 3839 } 3840 3841 namespace { 3842 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { 3843 public: 3844 using OpRewritePattern<CompressStoreOp>::OpRewritePattern; 3845 LogicalResult matchAndRewrite(CompressStoreOp compress, 3846 PatternRewriter &rewriter) const override { 3847 switch (get1DMaskFormat(compress.mask())) { 3848 case MaskFormat::AllTrue: 3849 rewriter.replaceOpWithNewOp<vector::StoreOp>( 3850 compress, compress.valueToStore(), compress.base(), 3851 compress.indices()); 3852 return success(); 3853 case MaskFormat::AllFalse: 3854 rewriter.eraseOp(compress); 3855 return success(); 3856 case MaskFormat::Unknown: 3857 return failure(); 3858 } 3859 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder"); 3860 } 3861 }; 3862 } // namespace 3863 3864 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 3865 MLIRContext *context) { 3866 results.add<CompressStoreFolder>(context); 3867 } 3868 3869 //===----------------------------------------------------------------------===// 3870 // ShapeCastOp 3871 //===----------------------------------------------------------------------===// 3872 3873 /// Returns true if each element of 'a' is equal to the product of a contiguous 3874 /// sequence of the elements of 'b'. Returns false otherwise. 
3875 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 3876 unsigned rankA = a.size(); 3877 unsigned rankB = b.size(); 3878 assert(rankA < rankB); 3879 3880 unsigned i = 0; 3881 unsigned j = 0; 3882 while (i < rankA && j < rankB) { 3883 int64_t dimA = a[i]; 3884 int64_t dimB = 1; 3885 while (dimB < dimA && j < rankB) 3886 dimB *= b[j++]; 3887 if (dimA != dimB) 3888 break; 3889 ++i; 3890 3891 // Handle the case when trailing dimensions are of size 1. 3892 // Include them into the contiguous sequence. 3893 auto isOne = [](int64_t v) { return v == 1; }; 3894 if (i < rankA && llvm::all_of(a.slice(i), isOne)) 3895 i = rankA; 3896 if (j < rankB && llvm::all_of(b.slice(j), isOne)) 3897 j = rankB; 3898 } 3899 3900 return i == rankA && j == rankB; 3901 } 3902 3903 static LogicalResult verifyVectorShapeCast(Operation *op, 3904 VectorType sourceVectorType, 3905 VectorType resultVectorType) { 3906 // Check that element type is the same. 3907 if (sourceVectorType.getElementType() != resultVectorType.getElementType()) 3908 return op->emitOpError("source/result vectors must have same element type"); 3909 auto sourceShape = sourceVectorType.getShape(); 3910 auto resultShape = resultVectorType.getShape(); 3911 3912 // Check that product of source dim sizes matches product of result dim sizes. 3913 int64_t sourceDimProduct = std::accumulate( 3914 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{}); 3915 int64_t resultDimProduct = std::accumulate( 3916 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{}); 3917 if (sourceDimProduct != resultDimProduct) 3918 return op->emitOpError("source/result number of elements must match"); 3919 3920 // Check that expanding/contracting rank cases. 3921 unsigned sourceRank = sourceVectorType.getRank(); 3922 unsigned resultRank = resultVectorType.getRank(); 3923 if (sourceRank < resultRank) { 3924 if (!isValidShapeCast(sourceShape, resultShape)) 3925 return op->emitOpError("invalid shape cast"); 3926 } else if (sourceRank > resultRank) { 3927 if (!isValidShapeCast(resultShape, sourceShape)) 3928 return op->emitOpError("invalid shape cast"); 3929 } 3930 return success(); 3931 } 3932 3933 static LogicalResult verify(ShapeCastOp op) { 3934 auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>(); 3935 auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>(); 3936 3937 // Check if source/result are of vector type. 3938 if (sourceVectorType && resultVectorType) 3939 return verifyVectorShapeCast(op, sourceVectorType, resultVectorType); 3940 3941 return success(); 3942 } 3943 3944 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) { 3945 // Nop shape cast. 3946 if (source().getType() == result().getType()) 3947 return source(); 3948 3949 // Canceling shape casts. 3950 if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) { 3951 if (result().getType() == otherOp.source().getType()) 3952 return otherOp.source(); 3953 3954 // Only allows valid transitive folding. 
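// E.g. (illustrative): vector<2x3xf32> -> vector<6xf32> -> vector<2x3x1xf32>
// folds to a single shape cast from the original source, whereas chains whose
// overall source and result have the same rank are left untouched below.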
3955 VectorType srcType = otherOp.source().getType().cast<VectorType>(); 3956 VectorType resultType = getResult().getType().cast<VectorType>(); 3957 if (srcType.getRank() < resultType.getRank()) { 3958 if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) 3959 return {}; 3960 } else if (srcType.getRank() > resultType.getRank()) { 3961 if (!isValidShapeCast(resultType.getShape(), srcType.getShape())) 3962 return {}; 3963 } else { 3964 return {}; 3965 } 3966 3967 setOperand(otherOp.source()); 3968 return getResult(); 3969 } 3970 return {}; 3971 } 3972 3973 namespace { 3974 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. 3975 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { 3976 public: 3977 using OpRewritePattern<ShapeCastOp>::OpRewritePattern; 3978 3979 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, 3980 PatternRewriter &rewriter) const override { 3981 auto constantOp = shapeCastOp.source().getDefiningOp<arith::ConstantOp>(); 3982 if (!constantOp) 3983 return failure(); 3984 // Only handle splat for now. 3985 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); 3986 if (!dense) 3987 return failure(); 3988 auto newAttr = 3989 DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(), 3990 dense.getSplatValue<Attribute>()); 3991 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr); 3992 return success(); 3993 } 3994 }; 3995 3996 } // namespace 3997 3998 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 3999 MLIRContext *context) { 4000 // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp. 4001 results.add<ShapeCastConstantFolder>(context); 4002 } 4003 4004 //===----------------------------------------------------------------------===// 4005 // VectorBitCastOp 4006 //===----------------------------------------------------------------------===// 4007 4008 static LogicalResult verify(BitCastOp op) { 4009 auto sourceVectorType = op.getSourceVectorType(); 4010 auto resultVectorType = op.getResultVectorType(); 4011 4012 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { 4013 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) 4014 return op.emitOpError("dimension size mismatch at: ") << i; 4015 } 4016 4017 DataLayout dataLayout = DataLayout::closest(op); 4018 auto sourceElementBits = 4019 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); 4020 auto resultElementBits = 4021 dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); 4022 4023 if (sourceVectorType.getRank() == 0) { 4024 if (sourceElementBits != resultElementBits) 4025 return op.emitOpError("source/result bitwidth of the 0-D vector element " 4026 "types must be equal"); 4027 } else if (sourceElementBits * sourceVectorType.getShape().back() != 4028 resultElementBits * resultVectorType.getShape().back()) { 4029 return op.emitOpError( 4030 "source/result bitwidth of the minor 1-D vectors must be equal"); 4031 } 4032 4033 return success(); 4034 } 4035 4036 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) { 4037 // Nop cast. 4038 if (source().getType() == result().getType()) 4039 return source(); 4040 4041 // Canceling bitcasts. 
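// E.g. (illustrative): bitcasting vector<4xf32> to vector<8xf16> and back to
// vector<4xf32> folds to the original value.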
4042 if (auto otherOp = source().getDefiningOp<BitCastOp>()) 4043 if (result().getType() == otherOp.source().getType()) 4044 return otherOp.source(); 4045 4046 Attribute sourceConstant = operands.front(); 4047 if (!sourceConstant) 4048 return {}; 4049 4050 Type srcElemType = getSourceVectorType().getElementType(); 4051 Type dstElemType = getResultVectorType().getElementType(); 4052 4053 if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) { 4054 if (floatPack.isSplat()) { 4055 auto splat = floatPack.getSplatValue<FloatAttr>(); 4056 4057 // Casting fp16 into fp32. 4058 if (srcElemType.isF16() && dstElemType.isF32()) { 4059 uint32_t bits = static_cast<uint32_t>( 4060 splat.getValue().bitcastToAPInt().getZExtValue()); 4061 // Duplicate the 16-bit pattern. 4062 bits = (bits << 16) | (bits & 0xffff); 4063 APInt intBits(32, bits); 4064 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits); 4065 return DenseElementsAttr::get(getResultVectorType(), floatBits); 4066 } 4067 } 4068 } 4069 4070 return {}; 4071 } 4072 4073 //===----------------------------------------------------------------------===// 4074 // TypeCastOp 4075 //===----------------------------------------------------------------------===// 4076 4077 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { 4078 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>(); 4079 SmallVector<int64_t, 8> res(memRefType.getShape().begin(), 4080 memRefType.getShape().end()); 4081 if (vectorType) 4082 res.append(vectorType.getShape().begin(), vectorType.getShape().end()); 4083 return res; 4084 } 4085 4086 /// Build the canonical memRefType with a single vector. 4087 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>. 4088 void TypeCastOp::build(OpBuilder &builder, OperationState &result, 4089 Value source) { 4090 result.addOperands(source); 4091 MemRefType memRefType = source.getType().cast<MemRefType>(); 4092 VectorType vectorType = 4093 VectorType::get(extractShape(memRefType), 4094 getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); 4095 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), 4096 memRefType.getMemorySpace())); 4097 } 4098 4099 static LogicalResult verify(TypeCastOp op) { 4100 MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType()); 4101 if (!canonicalType.getLayout().isIdentity()) 4102 return op.emitOpError( 4103 "expects operand to be a memref with identity layout"); 4104 if (!op.getResultMemRefType().getLayout().isIdentity()) 4105 return op.emitOpError("expects result to be a memref with identity layout"); 4106 if (op.getResultMemRefType().getMemorySpace() != 4107 op.getMemRefType().getMemorySpace()) 4108 return op.emitOpError("expects result in same memory space"); 4109 4110 auto sourceType = op.getMemRefType(); 4111 auto resultType = op.getResultMemRefType(); 4112 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != 4113 getElementTypeOrSelf(getElementTypeOrSelf(resultType))) 4114 return op.emitOpError( 4115 "expects result and operand with same underlying scalar type: ") 4116 << resultType; 4117 if (extractShape(sourceType) != extractShape(resultType)) 4118 return op.emitOpError( 4119 "expects concatenated result and operand shapes to be equal: ") 4120 << resultType; 4121 return success(); 4122 } 4123 4124 //===----------------------------------------------------------------------===// 4125 // TransposeOp 4126 //===----------------------------------------------------------------------===// 4127 4128 void 
vector::TransposeOp::build(OpBuilder &builder, OperationState &result, 4129 Value vector, ArrayRef<int64_t> transp) { 4130 VectorType vt = vector.getType().cast<VectorType>(); 4131 SmallVector<int64_t, 4> transposedShape(vt.getRank()); 4132 for (unsigned i = 0; i < transp.size(); ++i) 4133 transposedShape[i] = vt.getShape()[transp[i]]; 4134 4135 result.addOperands(vector); 4136 result.addTypes(VectorType::get(transposedShape, vt.getElementType())); 4137 result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp)); 4138 } 4139 4140 // Eliminates transpose operations, which produce values identical to their 4141 // input values. This happens when the dimensions of the input vector remain in 4142 // their original order after the transpose operation. 4143 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) { 4144 SmallVector<int64_t, 4> transp; 4145 getTransp(transp); 4146 4147 // Check if the permutation of the dimensions contains sequential values: 4148 // {0, 1, 2, ...}. 4149 for (int64_t i = 0, e = transp.size(); i < e; i++) { 4150 if (transp[i] != i) 4151 return {}; 4152 } 4153 4154 return vector(); 4155 } 4156 4157 static LogicalResult verify(vector::TransposeOp op) { 4158 VectorType vectorType = op.getVectorType(); 4159 VectorType resultType = op.getResultType(); 4160 int64_t rank = resultType.getRank(); 4161 if (vectorType.getRank() != rank) 4162 return op.emitOpError("vector result rank mismatch: ") << rank; 4163 // Verify transposition array. 4164 auto transpAttr = op.transp().getValue(); 4165 int64_t size = transpAttr.size(); 4166 if (rank != size) 4167 return op.emitOpError("transposition length mismatch: ") << size; 4168 SmallVector<bool, 8> seen(rank, false); 4169 for (const auto &ta : llvm::enumerate(transpAttr)) { 4170 int64_t i = ta.value().cast<IntegerAttr>().getInt(); 4171 if (i < 0 || i >= rank) 4172 return op.emitOpError("transposition index out of range: ") << i; 4173 if (seen[i]) 4174 return op.emitOpError("duplicate position index: ") << i; 4175 seen[i] = true; 4176 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) 4177 return op.emitOpError("dimension size mismatch at: ") << i; 4178 } 4179 return success(); 4180 } 4181 4182 namespace { 4183 4184 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. 4185 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { 4186 public: 4187 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 4188 4189 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 4190 PatternRewriter &rewriter) const override { 4191 // Wrapper around vector::TransposeOp::getTransp() for cleaner code. 4192 auto getPermutation = [](vector::TransposeOp transpose) { 4193 SmallVector<int64_t, 4> permutation; 4194 transpose.getTransp(permutation); 4195 return permutation; 4196 }; 4197 4198 // Composes two permutations: result[i] = permutation1[permutation2[i]]. 4199 auto composePermutations = [](ArrayRef<int64_t> permutation1, 4200 ArrayRef<int64_t> permutation2) { 4201 SmallVector<int64_t, 4> result; 4202 for (auto index : permutation2) 4203 result.push_back(permutation1[index]); 4204 return result; 4205 }; 4206 4207 // Return if the input of 'transposeOp' is not defined by another transpose. 
4208 vector::TransposeOp parentTransposeOp = 4209 transposeOp.vector().getDefiningOp<vector::TransposeOp>(); 4210 if (!parentTransposeOp) 4211 return failure(); 4212 4213 SmallVector<int64_t, 4> permutation = composePermutations( 4214 getPermutation(parentTransposeOp), getPermutation(transposeOp)); 4215 // Replace 'transposeOp' with a new transpose operation. 4216 rewriter.replaceOpWithNewOp<vector::TransposeOp>( 4217 transposeOp, transposeOp.getResult().getType(), 4218 parentTransposeOp.vector(), 4219 vector::getVectorSubscriptAttr(rewriter, permutation)); 4220 return success(); 4221 } 4222 }; 4223 4224 } // namespace 4225 4226 void vector::TransposeOp::getCanonicalizationPatterns( 4227 RewritePatternSet &results, MLIRContext *context) { 4228 results.add<TransposeFolder>(context); 4229 } 4230 4231 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) { 4232 populateFromInt64AttrArray(transp(), results); 4233 } 4234 4235 //===----------------------------------------------------------------------===// 4236 // ConstantMaskOp 4237 //===----------------------------------------------------------------------===// 4238 4239 static LogicalResult verify(ConstantMaskOp &op) { 4240 auto resultType = op.getResult().getType().cast<VectorType>(); 4241 // Check the corner case of 0-D vectors first. 4242 if (resultType.getRank() == 0) { 4243 if (op.mask_dim_sizes().size() != 1) 4244 return op->emitError("array attr must have length 1 for 0-D vectors"); 4245 auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt(); 4246 if (dim != 0 && dim != 1) 4247 return op->emitError( 4248 "mask dim size must be either 0 or 1 for 0-D vectors"); 4249 return success(); 4250 } 4251 4252 // Verify that array attr size matches the rank of the vector result. 4253 if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank()) 4254 return op.emitOpError( 4255 "must specify array attr of size equal vector result rank"); 4256 // Verify that each array attr element is in bounds of corresponding vector 4257 // result dimension size. 4258 auto resultShape = resultType.getShape(); 4259 SmallVector<int64_t, 4> maskDimSizes; 4260 for (const auto &it : llvm::enumerate(op.mask_dim_sizes())) { 4261 int64_t attrValue = it.value().cast<IntegerAttr>().getInt(); 4262 if (attrValue < 0 || attrValue > resultShape[it.index()]) 4263 return op.emitOpError( 4264 "array attr of size out of bounds of vector result dimension size"); 4265 maskDimSizes.push_back(attrValue); 4266 } 4267 // Verify that if one mask dim size is zero, they all should be zero (because 4268 // the mask region is a conjunction of each mask dimension interval). 4269 bool anyZeros = llvm::is_contained(maskDimSizes, 0); 4270 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); 4271 if (anyZeros && !allZeros) 4272 return op.emitOpError("expected all mask dim sizes to be zeros, " 4273 "as a result of conjunction with zero mask dim"); 4274 return success(); 4275 } 4276 4277 //===----------------------------------------------------------------------===// 4278 // CreateMaskOp 4279 //===----------------------------------------------------------------------===// 4280 4281 static LogicalResult verify(CreateMaskOp op) { 4282 auto vectorType = op.getResult().getType().cast<VectorType>(); 4283 // Verify that an operand was specified for each result vector each dimension. 
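// E.g. (illustrative): a vector<4x8xi1> result requires exactly two index
// operands, while a 0-D vector<i1> result takes exactly one.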
  if (vectorType.getRank() == 0) {
    if (op->getNumOperands() != 1)
      return op.emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (static_cast<int64_t>(op.getNumOperands()) !=
             vectorType.getRank()) {
    return op.emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}

namespace {

// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Bail out if any of 'createMaskOp's operands is not defined by a
    // constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto it : llvm::zip(createMaskOp.operands(),
                             createMaskOp.getType().getShape())) {
      auto *defOp = std::get<0>(it).getDefiningOp();
      int64_t maxDimSize = std::get<1>(it);
      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
      dimSize = std::min(dimSize, maxDimSize);
      // If one of the dim sizes is zero or negative, set all dims to zero.
      if (dimSize <= 0) {
        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
        break;
      }
      maskDimSizes.push_back(dimSize);
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};

} // namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ScanOp op) {
  VectorType srcType = op.getSourceType();
  VectorType initialType = op.getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = op.reduction_dim();
  if (reductionDim >= srcRank)
    return op.emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return op.emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
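  // Illustrative example (not from the original source): for a source
  //   %src : vector<4x8x16xf32> with reduction_dim = 1,
  // the initial value must have type vector<4x16xf32>, i.e. the source shape
  // with the reduction dimension removed.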
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != std::get<1>(s);
                   })) {
    return op.emitOpError("incompatible input/initial value shapes");
  }

  return success();
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
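
// Illustrative usage sketch (not part of the original file): the patterns
// collected by populateVectorToVectorCanonicalizationPatterns are typically
// driven from a pass via the greedy rewriter, e.g.
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorToVectorCanonicalizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
// where 'ctx' and 'rootOp' are hypothetical stand-ins for the pass's
// MLIRContext and root operation, and applyPatternsAndFoldGreedily is
// declared in "mlir/Transforms/GreedyPatternRewriteDriver.h".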