1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 10 #include "mlir/Dialect/Tensor/IR/Tensor.h" 11 #include "mlir/Dialect/Utils/StaticValueUtils.h" 12 #include "mlir/IR/BlockAndValueMapping.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "llvm/ADT/STLExtras.h" 18 19 using namespace mlir; 20 using namespace mlir::tensor; 21 22 /// Materialize a single constant operation from a given attribute value with 23 /// the desired resultant type. 24 Operation *TensorDialect::materializeConstant(OpBuilder &builder, 25 Attribute value, Type type, 26 Location loc) { 27 return builder.create<mlir::ConstantOp>(loc, type, value); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // CastOp 32 //===----------------------------------------------------------------------===// 33 34 /// Returns true if `target` is a ranked tensor type that preserves static 35 /// information available in the `source` ranked tensor type. 36 bool mlir::tensor::preservesStaticInformation(Type source, Type target) { 37 auto sourceType = source.dyn_cast<RankedTensorType>(); 38 auto targetType = target.dyn_cast<RankedTensorType>(); 39 40 // Requires RankedTensorType. 41 if (!sourceType || !targetType) 42 return false; 43 44 // Requires same elemental type. 45 if (sourceType.getElementType() != targetType.getElementType()) 46 return false; 47 48 // Requires same rank. 49 if (sourceType.getRank() != targetType.getRank()) 50 return false; 51 52 // If cast is towards more static sizes along any dimension, don't fold. 53 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) { 54 if (!ShapedType::isDynamic(std::get<0>(t)) && 55 ShapedType::isDynamic(std::get<1>(t))) 56 return false; 57 } 58 59 return true; 60 } 61 62 /// Determines whether tensor::CastOp casts to a more dynamic version of the 63 /// source tensor. This is useful to fold a tensor.cast into a consuming op and 64 /// implement canonicalization patterns for ops in different dialects that may 65 /// consume the results of tensor.cast operations. Such foldable tensor.cast 66 /// operations are typically inserted as `slice` ops and are canonicalized, 67 /// to preserve the type compatibility of their uses. 68 /// 69 /// Returns true when all conditions are met: 70 /// 1. source and result are ranked tensors with same element type and rank. 71 /// 2. the tensor type has more static information than the result 72 /// 73 /// Example: 74 /// ```mlir 75 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 76 /// %2 = consumer %1 ... : tensor<?x?xf32> ... 77 /// ``` 78 /// 79 /// folds into: 80 /// 81 /// ```mlir 82 /// %2 = consumer %0 ... : tensor<8x16xf32> ... 83 /// ``` 84 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { 85 if (!castOp) 86 return false; 87 88 // Can fold if the source of cast has at least as much static information as 89 // its results. 
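  // Note the argument order below: the cast's result type is passed as
  // `source` and the cast's operand type as `target`, so the check succeeds
  // exactly when every dimension that is static in the result is also static
  // in the operand (given matching rank and element type).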
90 return preservesStaticInformation(castOp.getType(), 91 castOp.source().getType()); 92 } 93 94 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp 95 /// that can be folded. 96 LogicalResult mlir::tensor::foldTensorCast(Operation *op) { 97 bool folded = false; 98 for (OpOperand &operand : op->getOpOperands()) { 99 auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); 100 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { 101 operand.set(castOp.getOperand()); 102 folded = true; 103 } 104 } 105 return success(folded); 106 } 107 108 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 109 if (inputs.size() != 1 || outputs.size() != 1) 110 return false; 111 Type a = inputs.front(), b = outputs.front(); 112 auto aT = a.dyn_cast<TensorType>(); 113 auto bT = b.dyn_cast<TensorType>(); 114 if (!aT || !bT) 115 return false; 116 117 if (aT.getElementType() != bT.getElementType()) 118 return false; 119 120 return succeeded(verifyCompatibleShape(aT, bT)); 121 } 122 123 /// Compute a TensorType that has the joined shape knowledge of the two 124 /// given TensorTypes. The element types need to match. 125 static TensorType joinShapes(TensorType one, TensorType two) { 126 assert(one.getElementType() == two.getElementType()); 127 128 if (!one.hasRank()) 129 return two; 130 if (!two.hasRank()) 131 return one; 132 133 int64_t rank = one.getRank(); 134 if (rank != two.getRank()) 135 return {}; 136 137 SmallVector<int64_t, 4> join; 138 join.reserve(rank); 139 for (int64_t i = 0; i < rank; ++i) { 140 if (one.isDynamicDim(i)) { 141 join.push_back(two.getDimSize(i)); 142 continue; 143 } 144 if (two.isDynamicDim(i)) { 145 join.push_back(one.getDimSize(i)); 146 continue; 147 } 148 if (one.getDimSize(i) != two.getDimSize(i)) 149 return {}; 150 join.push_back(one.getDimSize(i)); 151 } 152 return RankedTensorType::get(join, one.getElementType()); 153 } 154 155 namespace { 156 157 /// Replaces chains of two tensor.cast operations by a single tensor.cast 158 /// operation if doing so does not remove runtime constraints. 159 struct ChainedTensorCast : public OpRewritePattern<CastOp> { 160 using OpRewritePattern<CastOp>::OpRewritePattern; 161 162 LogicalResult matchAndRewrite(CastOp tensorCast, 163 PatternRewriter &rewriter) const final { 164 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>(); 165 166 if (!tensorCastOperand) 167 return failure(); 168 169 auto sourceType = 170 tensorCastOperand.getOperand().getType().cast<TensorType>(); 171 auto intermediateType = tensorCastOperand.getType().cast<TensorType>(); 172 auto resultType = tensorCast.getType().cast<TensorType>(); 173 174 // We can remove the intermediate cast if joining all three produces the 175 // same result as just joining the source and result shapes. 176 auto firstJoin = 177 joinShapes(joinShapes(sourceType, intermediateType), resultType); 178 179 // The join might not exist if the cast sequence would fail at runtime. 180 if (!firstJoin) 181 return failure(); 182 183 // The newJoin always exists if the above join exists, it might just contain 184 // less information. If so, we cannot drop the intermediate cast, as doing 185 // so would remove runtime checks. 
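  // Illustrative example: for the chain
  //   %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
  //   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
  // both joins are tensor<4x8xf32>, so the chain collapses into a single cast
  // from tensor<4x8xf32> to tensor<?x8xf32>. If instead the intermediate type
  // added information that neither end has (e.g. tensor<4x?xf32> between two
  // fully dynamic types), the joins would differ and the rewrite must be
  // rejected.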
186 auto newJoin = joinShapes(sourceType, resultType); 187 if (firstJoin != newJoin) 188 return failure(); 189 190 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType, 191 tensorCastOperand.getOperand()); 192 return success(); 193 } 194 }; 195 196 } // namespace 197 198 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, 199 MLIRContext *context) { 200 results.add<ChainedTensorCast>(context); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // DimOp 205 //===----------------------------------------------------------------------===// 206 207 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 208 int64_t index) { 209 auto loc = result.location; 210 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 211 build(builder, result, source, indexValue); 212 } 213 214 Optional<int64_t> DimOp::getConstantIndex() { 215 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 216 return constantOp.getValue().cast<IntegerAttr>().getInt(); 217 return {}; 218 } 219 220 static LogicalResult verify(DimOp op) { 221 // Assume unknown index to be in range. 222 Optional<int64_t> index = op.getConstantIndex(); 223 if (!index.hasValue()) 224 return success(); 225 226 // Check that constant index is not knowingly out of range. 227 auto type = op.source().getType(); 228 if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 229 if (index.getValue() >= tensorType.getRank()) 230 return op.emitOpError("index is out of range"); 231 } else if (type.isa<UnrankedTensorType>()) { 232 // Assume index to be in range. 233 } else { 234 llvm_unreachable("expected operand with tensor type"); 235 } 236 return success(); 237 } 238 239 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 240 // All forms of folding require a known index. 241 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 242 if (!index) 243 return {}; 244 245 // Folding for unranked types (UnrankedTensorType) is not supported. 246 auto tensorType = source().getType().dyn_cast<RankedTensorType>(); 247 if (!tensorType) 248 return {}; 249 250 // Fold if the shape extent along the given index is known. 251 if (!tensorType.isDynamicDim(index.getInt())) { 252 Builder builder(getContext()); 253 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]); 254 } 255 256 Operation *definingOp = source().getDefiningOp(); 257 258 // Fold dim to the operand of tensor.generate. 259 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 260 auto resultType = 261 fromElements.getResult().getType().cast<RankedTensorType>(); 262 // The case where the type encodes the size of the dimension is handled 263 // above. 264 assert(resultType.getShape()[index.getInt()] == 265 RankedTensorType::kDynamicSize); 266 267 // Find the operand of the fromElements that corresponds to this index. 268 auto dynExtents = fromElements.dynamicExtents().begin(); 269 for (auto dim : resultType.getShape().take_front(index.getInt())) 270 if (dim == RankedTensorType::kDynamicSize) 271 dynExtents++; 272 273 return Value{*dynExtents}; 274 } 275 276 // The size at the given index is now known to be a dynamic size. 
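  // Illustrative example: if the source is
  //   %t = tensor.extract_slice %s[0, 0][%sz0, %sz1][1, 1]
  //       : tensor<?x?xf32> to tensor<?x?xf32>
  // then dim(%t, 1) folds to the dynamic size operand %sz1 in the
  // ExtractSliceOp case below.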
277 unsigned unsignedIndex = index.getValue().getZExtValue(); 278 279 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) { 280 assert(sliceOp.isDynamicSize(unsignedIndex) && 281 "Expected dynamic slice size"); 282 return sliceOp.getDynamicSize(unsignedIndex); 283 } 284 285 // dim(cast) -> dim 286 if (succeeded(foldTensorCast(*this))) 287 return getResult(); 288 289 return {}; 290 } 291 292 namespace { 293 /// Fold dim of a cast into the dim of the source of the tensor cast. 294 struct DimOfCastOp : public OpRewritePattern<DimOp> { 295 using OpRewritePattern<DimOp>::OpRewritePattern; 296 297 LogicalResult matchAndRewrite(DimOp dimOp, 298 PatternRewriter &rewriter) const override { 299 auto castOp = dimOp.source().getDefiningOp<CastOp>(); 300 if (!castOp) 301 return failure(); 302 Value newSource = castOp.getOperand(); 303 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index()); 304 return success(); 305 } 306 }; 307 } // end anonymous namespace. 308 309 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 310 MLIRContext *context) { 311 results.add<DimOfCastOp>(context); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // ExtractOp 316 //===----------------------------------------------------------------------===// 317 318 static LogicalResult verify(ExtractOp op) { 319 // Verify the # indices match if we have a ranked type. 320 if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>()) 321 if (tensorType.getRank() != static_cast<int64_t>(op.indices().size())) 322 return op.emitOpError("incorrect number of indices for extract_element"); 323 324 return success(); 325 } 326 327 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) { 328 // The tensor operand must be a known constant. 329 Attribute tensor = operands.front(); 330 if (!tensor) 331 return {}; 332 // If this is a splat elements attribute, simply return the value. All of the 333 // elements of a splat attribute are the same. 334 if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>()) 335 return splatTensor.getSplatValue(); 336 337 // Otherwise, collect the constant indices into the tensor. 338 SmallVector<uint64_t, 8> indices; 339 for (Attribute indice : llvm::drop_begin(operands, 1)) { 340 if (!indice || !indice.isa<IntegerAttr>()) 341 return {}; 342 indices.push_back(indice.cast<IntegerAttr>().getInt()); 343 } 344 345 // If this is an elements attribute, query the value at the given indices. 
346 auto elementsAttr = tensor.dyn_cast<ElementsAttr>(); 347 if (elementsAttr && elementsAttr.isValidIndex(indices)) 348 return elementsAttr.getValue(indices); 349 return {}; 350 } 351 352 //===----------------------------------------------------------------------===// 353 // FromElementsOp 354 //===----------------------------------------------------------------------===// 355 356 void FromElementsOp::build(OpBuilder &builder, OperationState &result, 357 Type elementType, ValueRange elements) { 358 Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())}, 359 elementType); 360 result.addOperands(elements); 361 result.addTypes(resultTy); 362 } 363 364 void FromElementsOp::build(OpBuilder &builder, OperationState &result, 365 ValueRange elements) { 366 assert(!elements.empty() && "expected at least one element"); 367 build(builder, result, elements.front().getType(), elements); 368 } 369 370 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) { 371 if (!llvm::is_contained(operands, nullptr)) 372 return DenseElementsAttr::get(getType(), operands); 373 return {}; 374 } 375 376 namespace { 377 378 // Canonicalizes the pattern of the form 379 // 380 // %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 381 // %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32> 382 // 383 // to just %element. 384 struct ExtractElementFromTensorFromElements 385 : public OpRewritePattern<tensor::ExtractOp> { 386 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 387 388 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 389 PatternRewriter &rewriter) const final { 390 if (extract.indices().size() != 1) 391 return failure(); 392 393 auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>(); 394 if (tensorFromElements == nullptr) 395 return failure(); 396 397 APInt index; 398 if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) 399 return failure(); 400 // Prevent out of bounds accesses. This can happen in invalid code that will 401 // never execute. 402 if (tensorFromElements->getNumOperands() <= index.getZExtValue() || 403 index.getSExtValue() < 0) 404 return failure(); 405 rewriter.replaceOp(extract, 406 tensorFromElements.getOperand(index.getZExtValue())); 407 return success(); 408 } 409 }; 410 411 } // namespace 412 413 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, 414 MLIRContext *context) { 415 results.add<ExtractElementFromTensorFromElements>(context); 416 } 417 418 //===----------------------------------------------------------------------===// 419 // InsertOp 420 //===----------------------------------------------------------------------===// 421 422 static LogicalResult verify(InsertOp op) { 423 // Verify the # indices match if we have a ranked type. 
424 if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>()) 425 if (destType.getRank() != static_cast<int64_t>(op.indices().size())) 426 return op.emitOpError("incorrect number of indices"); 427 return success(); 428 } 429 430 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) { 431 Attribute scalar = operands[0]; 432 Attribute dest = operands[1]; 433 if (scalar && dest) 434 if (auto splatDest = dest.dyn_cast<SplatElementsAttr>()) 435 if (scalar == splatDest.getSplatValue()) 436 return dest; 437 return {}; 438 } 439 440 //===----------------------------------------------------------------------===// 441 // GenerateOp 442 //===----------------------------------------------------------------------===// 443 444 static LogicalResult verify(GenerateOp op) { 445 // Ensure that the tensor type has as many dynamic dimensions as are specified 446 // by the operands. 447 RankedTensorType resultTy = op.getType().cast<RankedTensorType>(); 448 if (op.getNumOperands() != resultTy.getNumDynamicDims()) 449 return op.emitError("must have as many index operands as dynamic extents " 450 "in the result type"); 451 452 // Ensure that region arguments span the index space. 453 if (!llvm::all_of(op.body().getArgumentTypes(), 454 [](Type ty) { return ty.isIndex(); })) 455 return op.emitError("all body arguments must be index"); 456 if (op.body().getNumArguments() != resultTy.getRank()) 457 return op.emitError("must have one body argument per input dimension"); 458 459 // Ensure that the region yields an element of the right type. 460 auto yieldOp = 461 llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator()); 462 if (yieldOp.value().getType() != resultTy.getElementType()) 463 return op.emitOpError( 464 "body must be terminated with a `yield` operation of the tensor " 465 "element type"); 466 467 return success(); 468 } 469 470 void GenerateOp::build( 471 OpBuilder &b, OperationState &result, Type resultTy, 472 ValueRange dynamicExtents, 473 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { 474 build(b, result, resultTy, dynamicExtents); 475 476 // Build and populate body. 477 OpBuilder::InsertionGuard guard(b); 478 Region *bodyRegion = result.regions.front().get(); 479 auto rank = resultTy.cast<RankedTensorType>().getRank(); 480 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType()); 481 Block *bodyBlock = 482 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes); 483 bodyBuilder(b, result.location, bodyBlock->getArguments()); 484 } 485 486 namespace { 487 488 /// Canonicalizes tensor.generate operations with a constant 489 /// operand into the equivalent operation with the operand expressed in the 490 /// result type, instead. We also insert a type cast to make sure that the 491 /// resulting IR is still well-typed. 
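///
/// Illustrative example (value names are hypothetical):
///
/// ```mlir
/// // %size is the result of a constant op producing the index value 5.
/// %0 = tensor.generate %size {
///   ...
/// } : tensor<?xindex>
/// ```
///
/// becomes a tensor.generate with no dynamic extents and result type
/// tensor<5xindex>, followed by a tensor.cast back to tensor<?xindex> so that
/// existing uses keep their original type.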
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.source(), extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
598 results.add<ExtractFromTensorGenerate, ExtractFromTensorCast, 599 StaticTensorGenerate>(context); 600 } 601 602 //===----------------------------------------------------------------------===// 603 // ReshapeOp 604 //===----------------------------------------------------------------------===// 605 606 static int64_t GetNumElements(ShapedType type) { 607 int64_t numElements = 1; 608 for (auto dim : type.getShape()) 609 numElements *= dim; 610 return numElements; 611 } 612 613 static LogicalResult verify(ReshapeOp op) { 614 TensorType operandType = op.source().getType().cast<TensorType>(); 615 TensorType resultType = op.result().getType().cast<TensorType>(); 616 617 if (operandType.getElementType() != resultType.getElementType()) 618 return op.emitOpError("element types of source and destination tensor " 619 "types should be the same"); 620 621 int64_t shapeSize = 622 op.shape().getType().cast<RankedTensorType>().getDimSize(0); 623 auto resultRankedType = resultType.dyn_cast<RankedTensorType>(); 624 auto operandRankedType = operandType.dyn_cast<RankedTensorType>(); 625 626 if (resultRankedType) { 627 if (operandRankedType && resultRankedType.hasStaticShape() && 628 operandRankedType.hasStaticShape()) { 629 if (GetNumElements(operandRankedType) != GetNumElements(resultRankedType)) 630 return op.emitOpError("source and destination tensor should have the " 631 "same number of elements"); 632 } 633 if (shapeSize == TensorType::kDynamicSize) 634 return op.emitOpError("cannot use shape operand with dynamic length to " 635 "reshape to statically-ranked tensor type"); 636 if (shapeSize != resultRankedType.getRank()) 637 return op.emitOpError( 638 "length of shape operand differs from the result's tensor rank"); 639 } 640 return success(); 641 } 642 643 //===----------------------------------------------------------------------===// 644 // ExtractSliceOp 645 //===----------------------------------------------------------------------===// 646 647 /// An extract_slice op result type can be fully inferred from the source type 648 /// and the static representation of offsets, sizes and strides. Special 649 /// sentinels encode the dynamic case. 650 Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, 651 ArrayRef<int64_t> leadingStaticOffsets, 652 ArrayRef<int64_t> leadingStaticSizes, 653 ArrayRef<int64_t> leadingStaticStrides) { 654 // An extract_slice op may specify only a leading subset of offset/sizes/ 655 // strides in which case we complete with offset=0, sizes from memref type and 656 // strides=1. 
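  // Illustrative example: for a source type tensor<8x16x4xf32> and
  // leadingStaticSizes = [2, 3], the inferred result type is
  // tensor<2x3x4xf32>; the missing trailing size is completed from the source
  // shape.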
657 unsigned rank = sourceRankedTensorType.getRank(); 658 assert(leadingStaticSizes.size() <= rank && 659 "unexpected leadingStaticSizes overflow"); 660 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 661 unsigned numTrailingSizes = rank - staticSizes.size(); 662 llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back( 663 numTrailingSizes)); 664 return RankedTensorType::get(staticSizes, 665 sourceRankedTensorType.getElementType()); 666 } 667 668 Type ExtractSliceOp::inferResultType( 669 RankedTensorType sourceRankedTensorType, 670 ArrayRef<OpFoldResult> leadingStaticOffsets, 671 ArrayRef<OpFoldResult> leadingStaticSizes, 672 ArrayRef<OpFoldResult> leadingStaticStrides) { 673 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 674 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 675 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 676 staticOffsets, ShapedType::kDynamicStrideOrOffset); 677 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 678 ShapedType::kDynamicSize); 679 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 680 staticStrides, ShapedType::kDynamicStrideOrOffset); 681 return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, 682 staticSizes, staticStrides); 683 } 684 685 /// An extract_slice op result type can be fully inferred from the source type 686 /// and the static representation of offsets, sizes and strides. Special 687 /// sentinels encode the dynamic case. 688 Type ExtractSliceOp::inferRankReducedResultType( 689 unsigned resultRank, RankedTensorType sourceRankedTensorType, 690 ArrayRef<int64_t> leadingStaticOffsets, 691 ArrayRef<int64_t> leadingStaticSizes, 692 ArrayRef<int64_t> leadingStaticStrides) { 693 auto inferredType = 694 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 695 leadingStaticSizes, leadingStaticStrides) 696 .cast<RankedTensorType>(); 697 int rankDiff = inferredType.getRank() - resultRank; 698 if (rankDiff > 0) { 699 auto shape = inferredType.getShape(); 700 llvm::SmallDenseSet<unsigned> dimsToProject; 701 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 702 SmallVector<int64_t> projectedShape; 703 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 704 if (!dimsToProject.contains(pos)) 705 projectedShape.push_back(shape[pos]); 706 inferredType = 707 RankedTensorType::get(projectedShape, inferredType.getElementType()); 708 } 709 return inferredType; 710 } 711 712 Type ExtractSliceOp::inferRankReducedResultType( 713 unsigned resultRank, RankedTensorType sourceRankedTensorType, 714 ArrayRef<OpFoldResult> leadingStaticOffsets, 715 ArrayRef<OpFoldResult> leadingStaticSizes, 716 ArrayRef<OpFoldResult> leadingStaticStrides) { 717 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 718 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 719 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 720 staticOffsets, ShapedType::kDynamicStrideOrOffset); 721 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 722 ShapedType::kDynamicSize); 723 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 724 staticStrides, ShapedType::kDynamicStrideOrOffset); 725 return ExtractSliceOp::inferRankReducedResultType( 726 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 727 staticStrides); 728 } 729 730 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom 731 /// result type. 
If the type passed is nullptr, it is inferred. 732 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, 733 RankedTensorType resultType, Value source, 734 ArrayRef<OpFoldResult> offsets, 735 ArrayRef<OpFoldResult> sizes, 736 ArrayRef<OpFoldResult> strides, 737 ArrayRef<NamedAttribute> attrs) { 738 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 739 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 740 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 741 ShapedType::kDynamicStrideOrOffset); 742 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 743 ShapedType::kDynamicSize); 744 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 745 ShapedType::kDynamicStrideOrOffset); 746 auto sourceRankedTensorType = source.getType().cast<RankedTensorType>(); 747 // Structuring implementation this way avoids duplication between builders. 748 if (!resultType) { 749 resultType = 750 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, 751 staticSizes, staticStrides) 752 .cast<RankedTensorType>(); 753 } 754 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 755 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 756 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 757 result.addAttributes(attrs); 758 } 759 760 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred 761 /// result type. 762 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, 763 ArrayRef<OpFoldResult> offsets, 764 ArrayRef<OpFoldResult> sizes, 765 ArrayRef<OpFoldResult> strides, 766 ArrayRef<NamedAttribute> attrs) { 767 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); 768 } 769 770 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the 771 /// type passed is nullptr, it is inferred. 772 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, 773 RankedTensorType resultType, Value source, 774 ValueRange offsets, ValueRange sizes, 775 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 776 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 777 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 778 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 779 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 780 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 781 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 782 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 783 } 784 785 /// Build an ExtractSliceOp with dynamic entries and inferred result type. 786 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, 787 ValueRange offsets, ValueRange sizes, 788 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 789 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); 790 } 791 792 enum SliceVerificationResult { 793 Success, 794 RankTooLarge, 795 SizeMismatch, 796 ElemTypeMismatch, 797 }; 798 799 /// Checks if `original` Type type can be rank reduced to `reduced` type. 800 /// This function is slight variant of `is subsequence` algorithm where 801 /// not matching dimension must be 1. 
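/// For example (illustrative), tensor<4x1x8xf32> can be rank-reduced to
/// tensor<4x8xf32> by dropping the unit dimension, but not to tensor<4x6xf32>,
/// because the remaining sizes do not match.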
802 static SliceVerificationResult 803 isRankReducedType(Type originalType, Type candidateReducedType, 804 std::string *errMsg = nullptr) { 805 if (originalType == candidateReducedType) 806 return SliceVerificationResult::Success; 807 if (!originalType.isa<RankedTensorType>()) 808 return SliceVerificationResult::Success; 809 if (originalType.isa<RankedTensorType>() && 810 !candidateReducedType.isa<RankedTensorType>()) 811 return SliceVerificationResult::Success; 812 813 ShapedType originalShapedType = originalType.cast<ShapedType>(); 814 ShapedType candidateReducedShapedType = 815 candidateReducedType.cast<ShapedType>(); 816 817 // Rank and size logic is valid for all ShapedTypes. 818 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 819 ArrayRef<int64_t> candidateReducedShape = 820 candidateReducedShapedType.getShape(); 821 unsigned originalRank = originalShape.size(), 822 candidateReducedRank = candidateReducedShape.size(); 823 if (candidateReducedRank > originalRank) 824 return SliceVerificationResult::RankTooLarge; 825 826 auto optionalUnusedDimsMask = 827 computeRankReductionMask(originalShape, candidateReducedShape); 828 829 // Sizes cannot be matched in case empty vector is returned. 830 if (!optionalUnusedDimsMask.hasValue()) 831 return SliceVerificationResult::SizeMismatch; 832 833 if (originalShapedType.getElementType() != 834 candidateReducedShapedType.getElementType()) 835 return SliceVerificationResult::ElemTypeMismatch; 836 837 // We are done for the tensor case. 838 if (originalType.isa<RankedTensorType>()) 839 return SliceVerificationResult::Success; 840 841 return SliceVerificationResult::Success; 842 } 843 844 template <typename OpTy> 845 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, 846 OpTy op, Type expectedType, 847 StringRef errMsg = "") { 848 auto memrefType = expectedType.cast<ShapedType>(); 849 switch (result) { 850 case SliceVerificationResult::Success: 851 return success(); 852 case SliceVerificationResult::RankTooLarge: 853 return op.emitError("expected result rank to be smaller or equal to ") 854 << "the source rank. " << errMsg; 855 case SliceVerificationResult::SizeMismatch: 856 return op.emitError("expected result type to be ") 857 << expectedType 858 << " or a rank-reduced version. (mismatch of result sizes) " 859 << errMsg; 860 case SliceVerificationResult::ElemTypeMismatch: 861 return op.emitError("expected result element type to be ") 862 << memrefType.getElementType() << errMsg; 863 } 864 llvm_unreachable("unexpected extract_slice op verification result"); 865 } 866 867 /// Verifier for ExtractSliceOp. 868 static LogicalResult verify(ExtractSliceOp op) { 869 // Verify result type against inferred type. 870 auto expectedType = ExtractSliceOp::inferResultType( 871 op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), 872 extractFromI64ArrayAttr(op.static_sizes()), 873 extractFromI64ArrayAttr(op.static_strides())); 874 auto result = isRankReducedType(expectedType, op.getType()); 875 return produceSliceErrorMsg(result, op, expectedType); 876 } 877 878 /// Infer the canonical type of the result of an extract_slice op. Returns a 879 /// type with rank `resultRank` that is either the rank of the rank-reduced 880 /// type, or the non-rank-reduced type. 
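/// For example (illustrative), a slice with sizes [1, 4] taken from a
/// tensor<8x16xf32> source infers tensor<4xf32> when `resultRank` is 1
/// (rank-reduced) and tensor<1x4xf32> when `resultRank` is 2.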
static RankedTensorType
getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
                            ArrayRef<OpFoldResult> mixedOffsets,
                            ArrayRef<OpFoldResult> mixedSizes,
                            ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      ExtractSliceOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<RankedTensorType>();
  if (resultType.getRank() != resultRank) {
    resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
                                                 mixedSizes, mixedStrides)
                     .cast<RankedTensorType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite an extract_slice op with a tensor::CastOp source.
/// This essentially pushes the tensor.cast past its consuming extract_slice
/// when `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
/// tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
/// tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let the constant-argument
    // folder kick in first.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // Deduce the type of the result to use for the canonicalized operation.
    RankedTensorType resultType = getCanonicalSliceResultType(
        sliceOp.getType().getRank(), sliceOp.getSourceType(),
        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
        sliceOp.getMixedStrides());
    Value newSlice = rewriter.create<ExtractSliceOp>(
        sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
        sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
        sliceOp.static_sizes(), sliceOp.static_strides());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
                                                newSlice);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSliceResultType(op.getType().getRank(),
                                       op.getSourceType(), mixedOffsets,
                                       mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace ExtractSliceOps.
struct SliceCanonicalizer {
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }
};

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}

/// Returns success if `op` describes an identity slice of `shapedType`, i.e.
/// all offsets are 0, all strides are 1, and the sizes match the shape.
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  OpBuilder b(op.getContext());
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
      return failure();
  // Rank-reducing no-ops only need to inspect the leading dimensions:
  // llvm::zip is appropriate.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
      return failure();
  return success();
}

OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          ArrayRef<OpFoldResult> strides,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build an InsertSliceOp with dynamic entries.
1033 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, 1034 Value dest, ValueRange offsets, ValueRange sizes, 1035 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 1036 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1037 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1038 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1039 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1040 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1041 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1042 build(b, result, source, dest, offsetValues, sizeValues, strideValues); 1043 } 1044 1045 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) { 1046 if (getSourceType().hasStaticShape() && getType().hasStaticShape() && 1047 getSourceType() == getType() && 1048 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) 1049 return this->source(); 1050 return OpFoldResult(); 1051 } 1052 1053 LogicalResult InsertSliceOp::reifyResultShapes( 1054 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 1055 reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank())); 1056 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) { 1057 reifiedReturnShapes[0][dim] = 1058 builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim); 1059 } 1060 return success(); 1061 } 1062 1063 namespace { 1064 /// Pattern to rewrite a insert_slice op with constant arguments. 1065 class InsertSliceOpConstantArgumentFolder final 1066 : public OpRewritePattern<InsertSliceOp> { 1067 public: 1068 using OpRewritePattern<InsertSliceOp>::OpRewritePattern; 1069 1070 LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, 1071 PatternRewriter &rewriter) const override { 1072 // No constant operand, just return. 1073 if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) { 1074 return matchPattern(operand, matchConstantIndex()); 1075 })) 1076 return failure(); 1077 1078 // At least one of offsets/sizes/strides is a new constant. 1079 // Form the new list of operands and constant attributes from the 1080 // existing. 1081 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets()); 1082 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes()); 1083 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides()); 1084 canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset); 1085 canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); 1086 canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); 1087 1088 // Create the new op in canonical form. 1089 rewriter.replaceOpWithNewOp<InsertSliceOp>( 1090 insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(), 1091 mixedOffsets, mixedSizes, mixedStrides); 1092 return success(); 1093 } 1094 }; 1095 1096 /// Fold tensor_casts with insert_slice operations. If the source or destination 1097 /// tensor is a tensor_cast that removes static type information, the cast is 1098 /// folded into the insert_slice operation. E.g.: 1099 /// 1100 /// ```mlir 1101 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 1102 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ... 1103 /// ``` 1104 /// 1105 /// folds into: 1106 /// 1107 /// ```mlir 1108 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ... 
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast to ensure that the type of the result does
/// not change.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return llvm::None;
      return castOp.source();
    };
    Optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.source());
    Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    Value replacement = rewriter.create<InsertSliceOp>(
        insertSliceOp.getLoc(),
        (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
        (destCastSource ? *destCastSource : insertSliceOp.dest()),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

    if (replacement.getType() != insertSliceOp.getType()) {
      replacement = rewriter.create<tensor::CastOp>(
          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
    }
    rewriter.replaceOp(insertSliceOp, replacement);
    return success();
  }
};

/// If additional static type information can be deduced from an insert_slice's
/// size operands, insert an explicit cast of the op's source operand. This
/// enables other canonicalization patterns that match tensor.cast ops, such as
/// `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
///     : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
///     : tensor<64x64xf32> into ...
/// ```
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                     srcType.getShape().end());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (Optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
        newSrcShape[i] = *constInt;
    }
    RankedTensorType newSrcType =
        RankedTensorType::get(newSrcShape, srcType.getElementType());
    if (srcType == newSrcType)
      return failure();

    // srcType and newSrcType are different. Insert a cast.
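    // Inserting the cast on the source is safe: the result type of the
    // insert_slice is derived from the destination tensor, so it does not
    // change under this rewrite.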
1193 Value cast = rewriter.create<tensor::CastOp>( 1194 insertSliceOp.getLoc(), newSrcType, insertSliceOp.source()); 1195 rewriter.replaceOpWithNewOp<InsertSliceOp>( 1196 insertSliceOp, cast, insertSliceOp.dest(), 1197 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), 1198 insertSliceOp.getMixedStrides()); 1199 return success(); 1200 } 1201 }; 1202 } // namespace 1203 1204 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, 1205 MLIRContext *context) { 1206 results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder, 1207 InsertSliceOpSourceCastInserter>(context); 1208 } 1209 1210 //===----------------------------------------------------------------------===// 1211 // TableGen'd op method definitions 1212 //===----------------------------------------------------------------------===// 1213 1214 #define GET_OP_CLASSES 1215 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc" 1216