1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 11 #include "mlir/Dialect/Tensor/IR/Tensor.h" 12 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/BlockAndValueMapping.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinAttributeInterfaces.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "llvm/ADT/STLExtras.h" 21 22 using namespace mlir; 23 using namespace mlir::tensor; 24 25 /// Materialize a single constant operation from a given attribute value with 26 /// the desired resultant type. 27 Operation *TensorDialect::materializeConstant(OpBuilder &builder, 28 Attribute value, Type type, 29 Location loc) { 30 if (arith::ConstantOp::isBuildableWith(value, type)) 31 return builder.create<arith::ConstantOp>(loc, value, type); 32 if (ConstantOp::isBuildableWith(value, type)) 33 return builder.create<ConstantOp>(loc, value, type); 34 return nullptr; 35 } 36 37 //===----------------------------------------------------------------------===// 38 // CastOp 39 //===----------------------------------------------------------------------===// 40 41 /// Returns true if `target` is a ranked tensor type that preserves static 42 /// information available in the `source` ranked tensor type. 43 bool mlir::tensor::preservesStaticInformation(Type source, Type target) { 44 auto sourceType = source.dyn_cast<RankedTensorType>(); 45 auto targetType = target.dyn_cast<RankedTensorType>(); 46 47 // Requires RankedTensorType. 48 if (!sourceType || !targetType) 49 return false; 50 51 // Requires same elemental type. 52 if (sourceType.getElementType() != targetType.getElementType()) 53 return false; 54 55 // Requires same rank. 56 if (sourceType.getRank() != targetType.getRank()) 57 return false; 58 59 // If cast is towards more static sizes along any dimension, don't fold. 60 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) { 61 if (!ShapedType::isDynamic(std::get<0>(t)) && 62 ShapedType::isDynamic(std::get<1>(t))) 63 return false; 64 } 65 66 return true; 67 } 68 69 /// Determines whether tensor::CastOp casts to a more dynamic version of the 70 /// source tensor. This is useful to fold a tensor.cast into a consuming op and 71 /// implement canonicalization patterns for ops in different dialects that may 72 /// consume the results of tensor.cast operations. Such foldable tensor.cast 73 /// operations are typically inserted as `slice` ops and are canonicalized, 74 /// to preserve the type compatibility of their uses. 75 /// 76 /// Returns true when all conditions are met: 77 /// 1. source and result are ranked tensors with same element type and rank. 78 /// 2. the tensor type has more static information than the result 79 /// 80 /// Example: 81 /// ```mlir 82 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 83 /// %2 = consumer %1 ... : tensor<?x?xf32> ... 84 /// ``` 85 /// 86 /// folds into: 87 /// 88 /// ```mlir 89 /// %2 = consumer %0 ... : tensor<8x16xf32> ... 90 /// ``` 91 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { 92 if (!castOp) 93 return false; 94 95 // Can fold if the source of cast has at least as much static information as 96 // its results. 97 return preservesStaticInformation(castOp.getType(), 98 castOp.source().getType()); 99 } 100 101 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp 102 /// that can be folded. 103 LogicalResult mlir::tensor::foldTensorCast(Operation *op) { 104 bool folded = false; 105 for (OpOperand &operand : op->getOpOperands()) { 106 auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); 107 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { 108 operand.set(castOp.getOperand()); 109 folded = true; 110 } 111 } 112 return success(folded); 113 } 114 115 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 116 if (inputs.size() != 1 || outputs.size() != 1) 117 return false; 118 Type a = inputs.front(), b = outputs.front(); 119 auto aT = a.dyn_cast<TensorType>(); 120 auto bT = b.dyn_cast<TensorType>(); 121 if (!aT || !bT) 122 return false; 123 124 if (aT.getElementType() != bT.getElementType()) 125 return false; 126 127 return succeeded(verifyCompatibleShape(aT, bT)); 128 } 129 130 /// Compute a TensorType that has the joined shape knowledge of the two 131 /// given TensorTypes. The element types need to match. 132 static TensorType joinShapes(TensorType one, TensorType two) { 133 assert(one.getElementType() == two.getElementType()); 134 135 if (!one.hasRank()) 136 return two; 137 if (!two.hasRank()) 138 return one; 139 140 int64_t rank = one.getRank(); 141 if (rank != two.getRank()) 142 return {}; 143 144 SmallVector<int64_t, 4> join; 145 join.reserve(rank); 146 for (int64_t i = 0; i < rank; ++i) { 147 if (one.isDynamicDim(i)) { 148 join.push_back(two.getDimSize(i)); 149 continue; 150 } 151 if (two.isDynamicDim(i)) { 152 join.push_back(one.getDimSize(i)); 153 continue; 154 } 155 if (one.getDimSize(i) != two.getDimSize(i)) 156 return {}; 157 join.push_back(one.getDimSize(i)); 158 } 159 return RankedTensorType::get(join, one.getElementType()); 160 } 161 162 namespace { 163 164 /// Replaces chains of two tensor.cast operations by a single tensor.cast 165 /// operation if doing so does not remove runtime constraints. 166 struct ChainedTensorCast : public OpRewritePattern<CastOp> { 167 using OpRewritePattern<CastOp>::OpRewritePattern; 168 169 LogicalResult matchAndRewrite(CastOp tensorCast, 170 PatternRewriter &rewriter) const final { 171 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>(); 172 173 if (!tensorCastOperand) 174 return failure(); 175 176 auto sourceType = 177 tensorCastOperand.getOperand().getType().cast<TensorType>(); 178 auto intermediateType = tensorCastOperand.getType().cast<TensorType>(); 179 auto resultType = tensorCast.getType().cast<TensorType>(); 180 181 // We can remove the intermediate cast if joining all three produces the 182 // same result as just joining the source and result shapes. 183 auto firstJoin = 184 joinShapes(joinShapes(sourceType, intermediateType), resultType); 185 186 // The join might not exist if the cast sequence would fail at runtime. 187 if (!firstJoin) 188 return failure(); 189 190 // The newJoin always exists if the above join exists, it might just contain 191 // less information. If so, we cannot drop the intermediate cast, as doing 192 // so would remove runtime checks. 193 auto newJoin = joinShapes(sourceType, resultType); 194 if (firstJoin != newJoin) 195 return failure(); 196 197 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType, 198 tensorCastOperand.getOperand()); 199 return success(); 200 } 201 }; 202 203 } // namespace 204 205 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, 206 MLIRContext *context) { 207 results.add<ChainedTensorCast>(context); 208 } 209 210 //===----------------------------------------------------------------------===// 211 // DimOp 212 //===----------------------------------------------------------------------===// 213 214 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 215 int64_t index) { 216 auto loc = result.location; 217 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); 218 build(builder, result, source, indexValue); 219 } 220 221 Optional<int64_t> DimOp::getConstantIndex() { 222 if (auto constantOp = index().getDefiningOp<arith::ConstantOp>()) 223 return constantOp.getValue().cast<IntegerAttr>().getInt(); 224 return {}; 225 } 226 227 static LogicalResult verify(DimOp op) { 228 // Assume unknown index to be in range. 229 Optional<int64_t> index = op.getConstantIndex(); 230 if (!index.hasValue()) 231 return success(); 232 233 // Check that constant index is not knowingly out of range. 234 auto type = op.source().getType(); 235 if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 236 if (index.getValue() >= tensorType.getRank()) 237 return op.emitOpError("index is out of range"); 238 } else if (type.isa<UnrankedTensorType>()) { 239 // Assume index to be in range. 240 } else { 241 llvm_unreachable("expected operand with tensor type"); 242 } 243 return success(); 244 } 245 246 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 247 // All forms of folding require a known index. 248 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 249 if (!index) 250 return {}; 251 252 // Folding for unranked types (UnrankedTensorType) is not supported. 253 auto tensorType = source().getType().dyn_cast<RankedTensorType>(); 254 if (!tensorType) 255 return {}; 256 257 // Fold if the shape extent along the given index is known. 258 if (!tensorType.isDynamicDim(index.getInt())) { 259 Builder builder(getContext()); 260 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]); 261 } 262 263 Operation *definingOp = source().getDefiningOp(); 264 265 // Fold dim to the operand of tensor.generate. 266 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 267 auto resultType = 268 fromElements.getResult().getType().cast<RankedTensorType>(); 269 // The case where the type encodes the size of the dimension is handled 270 // above. 271 assert(resultType.getShape()[index.getInt()] == 272 RankedTensorType::kDynamicSize); 273 274 // Find the operand of the fromElements that corresponds to this index. 275 auto dynExtents = fromElements.dynamicExtents().begin(); 276 for (auto dim : resultType.getShape().take_front(index.getInt())) 277 if (dim == RankedTensorType::kDynamicSize) 278 dynExtents++; 279 280 return Value{*dynExtents}; 281 } 282 283 // The size at the given index is now known to be a dynamic size. 284 unsigned unsignedIndex = index.getValue().getZExtValue(); 285 286 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) { 287 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on 288 // `resolve-shaped-type-result-dims` pass. 289 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() && 290 sliceOp.isDynamicSize(unsignedIndex)) { 291 return {sliceOp.getDynamicSize(unsignedIndex)}; 292 } 293 } 294 295 // dim(cast) -> dim 296 if (succeeded(foldTensorCast(*this))) 297 return getResult(); 298 299 return {}; 300 } 301 302 namespace { 303 /// Fold dim of a cast into the dim of the source of the tensor cast. 304 struct DimOfCastOp : public OpRewritePattern<DimOp> { 305 using OpRewritePattern<DimOp>::OpRewritePattern; 306 307 LogicalResult matchAndRewrite(DimOp dimOp, 308 PatternRewriter &rewriter) const override { 309 auto castOp = dimOp.source().getDefiningOp<CastOp>(); 310 if (!castOp) 311 return failure(); 312 Value newSource = castOp.getOperand(); 313 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index()); 314 return success(); 315 } 316 }; 317 } // namespace 318 319 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 320 MLIRContext *context) { 321 results.add<DimOfCastOp>(context); 322 } 323 324 //===----------------------------------------------------------------------===// 325 // ExtractOp 326 //===----------------------------------------------------------------------===// 327 328 static LogicalResult verify(ExtractOp op) { 329 // Verify the # indices match if we have a ranked type. 330 if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>()) 331 if (tensorType.getRank() != static_cast<int64_t>(op.indices().size())) 332 return op.emitOpError("incorrect number of indices for extract_element"); 333 334 return success(); 335 } 336 337 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) { 338 // The tensor operand must be a known constant. 339 Attribute tensor = operands.front(); 340 if (!tensor) 341 return {}; 342 // If this is a splat elements attribute, simply return the value. All of the 343 // elements of a splat attribute are the same. 344 if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>()) 345 return splatTensor.getSplatValue<Attribute>(); 346 347 // Otherwise, collect the constant indices into the tensor. 348 SmallVector<uint64_t, 8> indices; 349 for (Attribute indice : llvm::drop_begin(operands, 1)) { 350 if (!indice || !indice.isa<IntegerAttr>()) 351 return {}; 352 indices.push_back(indice.cast<IntegerAttr>().getInt()); 353 } 354 355 // If this is an elements attribute, query the value at the given indices. 356 auto elementsAttr = tensor.dyn_cast<ElementsAttr>(); 357 if (elementsAttr && elementsAttr.isValidIndex(indices)) 358 return elementsAttr.getValues<Attribute>()[indices]; 359 return {}; 360 } 361 362 //===----------------------------------------------------------------------===// 363 // FromElementsOp 364 //===----------------------------------------------------------------------===// 365 366 void FromElementsOp::build(OpBuilder &builder, OperationState &result, 367 Type resultType, ValueRange elements) { 368 result.addOperands(elements); 369 result.addTypes(resultType); 370 } 371 372 void FromElementsOp::build(OpBuilder &builder, OperationState &result, 373 ValueRange elements) { 374 assert(!elements.empty() && "expected at least one element"); 375 Type resultType = RankedTensorType::get( 376 {static_cast<int64_t>(elements.size())}, elements.front().getType()); 377 build(builder, result, resultType, elements); 378 } 379 380 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) { 381 if (!llvm::is_contained(operands, nullptr)) 382 return DenseElementsAttr::get(getType(), operands); 383 return {}; 384 } 385 386 namespace { 387 388 // Canonicalizes the pattern of the form 389 // 390 // %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> 391 // %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32> 392 // 393 // to just %element. 394 struct ExtractElementFromTensorFromElements 395 : public OpRewritePattern<tensor::ExtractOp> { 396 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 397 398 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 399 PatternRewriter &rewriter) const final { 400 auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>(); 401 if (!tensorFromElements) 402 return failure(); 403 auto tensorType = tensorFromElements.getType().cast<RankedTensorType>(); 404 auto rank = tensorType.getRank(); 405 if (rank == 0) { 406 rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); 407 return success(); 408 } 409 SmallVector<APInt, 3> indices(rank); 410 int64_t flatIndex = 0; 411 int64_t stride = 1; 412 for (int i = rank - 1; i >= 0; --i) { 413 APInt index; 414 if (!matchPattern(extract.indices()[i], m_ConstantInt(&index))) 415 return failure(); 416 if (i < rank - 1) 417 stride *= tensorType.getDimSize(i); 418 flatIndex += index.getSExtValue() * stride; 419 } 420 // Prevent out of bounds accesses. This can happen in invalid code that will 421 // never execute. 422 if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0) 423 return failure(); 424 rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex)); 425 return success(); 426 } 427 }; 428 429 } // namespace 430 431 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, 432 MLIRContext *context) { 433 results.add<ExtractElementFromTensorFromElements>(context); 434 } 435 436 //===----------------------------------------------------------------------===// 437 // InsertOp 438 //===----------------------------------------------------------------------===// 439 440 static LogicalResult verify(InsertOp op) { 441 // Verify the # indices match if we have a ranked type. 442 if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>()) 443 if (destType.getRank() != static_cast<int64_t>(op.indices().size())) 444 return op.emitOpError("incorrect number of indices"); 445 return success(); 446 } 447 448 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) { 449 Attribute scalar = operands[0]; 450 Attribute dest = operands[1]; 451 if (scalar && dest) 452 if (auto splatDest = dest.dyn_cast<SplatElementsAttr>()) 453 if (scalar == splatDest.getSplatValue<Attribute>()) 454 return dest; 455 return {}; 456 } 457 458 //===----------------------------------------------------------------------===// 459 // GenerateOp 460 //===----------------------------------------------------------------------===// 461 462 static LogicalResult verify(GenerateOp op) { 463 // Ensure that the tensor type has as many dynamic dimensions as are specified 464 // by the operands. 465 RankedTensorType resultTy = op.getType().cast<RankedTensorType>(); 466 if (op.getNumOperands() != resultTy.getNumDynamicDims()) 467 return op.emitError("must have as many index operands as dynamic extents " 468 "in the result type"); 469 470 // Ensure that region arguments span the index space. 471 if (!llvm::all_of(op.body().getArgumentTypes(), 472 [](Type ty) { return ty.isIndex(); })) 473 return op.emitError("all body arguments must be index"); 474 if (op.body().getNumArguments() != resultTy.getRank()) 475 return op.emitError("must have one body argument per input dimension"); 476 477 // Ensure that the region yields an element of the right type. 478 auto yieldOp = 479 llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator()); 480 if (yieldOp.value().getType() != resultTy.getElementType()) 481 return op.emitOpError( 482 "body must be terminated with a `yield` operation of the tensor " 483 "element type"); 484 485 return success(); 486 } 487 488 void GenerateOp::build( 489 OpBuilder &b, OperationState &result, Type resultTy, 490 ValueRange dynamicExtents, 491 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { 492 build(b, result, resultTy, dynamicExtents); 493 494 // Build and populate body. 495 OpBuilder::InsertionGuard guard(b); 496 Region *bodyRegion = result.regions.front().get(); 497 auto rank = resultTy.cast<RankedTensorType>().getRank(); 498 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType()); 499 Block *bodyBlock = 500 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes); 501 bodyBuilder(b, result.location, bodyBlock->getArguments()); 502 } 503 504 namespace { 505 506 /// Canonicalizes tensor.generate operations with a constant 507 /// operand into the equivalent operation with the operand expressed in the 508 /// result type, instead. We also insert a type cast to make sure that the 509 /// resulting IR is still well-typed. 510 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> { 511 using OpRewritePattern<GenerateOp>::OpRewritePattern; 512 513 LogicalResult matchAndRewrite(GenerateOp tensorFromElements, 514 PatternRewriter &rewriter) const final { 515 auto resultType = 516 tensorFromElements.getResult().getType().cast<RankedTensorType>(); 517 518 if (resultType.hasStaticShape()) 519 return failure(); 520 521 SmallVector<Value, 4> newOperands; 522 SmallVector<int64_t, 4> newShape; 523 auto operandsIt = tensorFromElements.dynamicExtents().begin(); 524 525 for (int64_t dim : resultType.getShape()) { 526 if (dim != RankedTensorType::kDynamicSize) { 527 newShape.push_back(dim); 528 continue; 529 } 530 APInt index; 531 if (!matchPattern(*operandsIt, m_ConstantInt(&index))) { 532 newShape.push_back(RankedTensorType::kDynamicSize); 533 newOperands.push_back(*operandsIt++); 534 continue; 535 } 536 newShape.push_back(index.getSExtValue()); 537 operandsIt++; 538 } 539 540 if (newOperands.size() == tensorFromElements.dynamicExtents().size()) 541 return failure(); 542 543 auto loc = tensorFromElements.getLoc(); 544 auto newOp = rewriter.create<GenerateOp>( 545 loc, RankedTensorType::get(newShape, resultType.getElementType()), 546 newOperands); 547 rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(), 548 newOp.body().begin()); 549 rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType, 550 newOp); 551 return success(); 552 } 553 }; 554 555 /// Canonicalizes the pattern of the form 556 /// 557 /// %tensor = tensor.generate %x { 558 /// ^bb0(%arg0: index): // no predecessors 559 /// <computation> 560 /// yield %1 : index 561 /// } : tensor<?xindex> 562 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32> 563 /// 564 /// to just <computation> with %arg0 replaced by %c0. We only do this if the 565 /// tensor.generate operation has no side-effects. 566 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> { 567 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 568 569 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 570 PatternRewriter &rewriter) const final { 571 auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>(); 572 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements)) 573 return failure(); 574 575 BlockAndValueMapping mapping; 576 Block *body = tensorFromElements.getBody(); 577 mapping.map(body->getArguments(), extract.indices()); 578 for (auto &op : body->without_terminator()) 579 rewriter.clone(op, mapping); 580 581 auto yield = cast<YieldOp>(body->getTerminator()); 582 583 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value())); 584 return success(); 585 } 586 }; 587 588 /// Canonicalizes the pattern of the form 589 /// 590 /// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32> 591 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> 592 /// 593 /// to 594 /// 595 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32> 596 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> { 597 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 598 599 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 600 PatternRewriter &rewriter) const final { 601 auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>(); 602 if (!tensorCast) 603 return failure(); 604 605 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(), 606 extract.indices()); 607 return success(); 608 } 609 }; 610 611 } // namespace 612 613 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results, 614 MLIRContext *context) { 615 // TODO: Move extract patterns to tensor::ExtractOp. 616 results.add<ExtractFromTensorGenerate, ExtractFromTensorCast, 617 StaticTensorGenerate>(context); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // RankOp 622 //===----------------------------------------------------------------------===// 623 624 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 625 // Constant fold rank when the rank of the operand is known. 626 auto type = getOperand().getType(); 627 auto shapedType = type.dyn_cast<ShapedType>(); 628 if (shapedType && shapedType.hasRank()) 629 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 630 return IntegerAttr(); 631 } 632 633 //===----------------------------------------------------------------------===// 634 // ReshapeOp 635 //===----------------------------------------------------------------------===// 636 637 static int64_t getNumElements(ShapedType type) { 638 int64_t numElements = 1; 639 for (auto dim : type.getShape()) 640 numElements *= dim; 641 return numElements; 642 } 643 644 static LogicalResult verify(ReshapeOp op) { 645 TensorType operandType = op.source().getType().cast<TensorType>(); 646 TensorType resultType = op.result().getType().cast<TensorType>(); 647 648 if (operandType.getElementType() != resultType.getElementType()) 649 return op.emitOpError("element types of source and destination tensor " 650 "types should be the same"); 651 652 int64_t shapeSize = 653 op.shape().getType().cast<RankedTensorType>().getDimSize(0); 654 auto resultRankedType = resultType.dyn_cast<RankedTensorType>(); 655 auto operandRankedType = operandType.dyn_cast<RankedTensorType>(); 656 657 if (resultRankedType) { 658 if (operandRankedType && resultRankedType.hasStaticShape() && 659 operandRankedType.hasStaticShape()) { 660 if (getNumElements(operandRankedType) != getNumElements(resultRankedType)) 661 return op.emitOpError("source and destination tensor should have the " 662 "same number of elements"); 663 } 664 if (shapeSize == TensorType::kDynamicSize) 665 return op.emitOpError("cannot use shape operand with dynamic length to " 666 "reshape to statically-ranked tensor type"); 667 if (shapeSize != resultRankedType.getRank()) 668 return op.emitOpError( 669 "length of shape operand differs from the result's tensor rank"); 670 } 671 return success(); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // Reassociative reshape ops 676 //===----------------------------------------------------------------------===// 677 678 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 679 return getSymbolLessAffineMaps(getReassociationExprs()); 680 } 681 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 682 return convertReassociationIndicesToExprs(getContext(), 683 getReassociationIndices()); 684 } 685 686 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 687 return getSymbolLessAffineMaps(getReassociationExprs()); 688 } 689 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 690 return convertReassociationIndicesToExprs(getContext(), 691 getReassociationIndices()); 692 } 693 694 static void print(OpAsmPrinter &p, ExpandShapeOp op) { 695 ::mlir::printReshapeOp<ExpandShapeOp>(p, op); 696 } 697 698 static void print(OpAsmPrinter &p, CollapseShapeOp op) { 699 ::mlir::printReshapeOp<CollapseShapeOp>(p, op); 700 } 701 702 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`. 703 static RankedTensorType 704 computeTensorReshapeCollapsedType(RankedTensorType type, 705 ArrayRef<AffineMap> reassociation) { 706 auto shape = type.getShape(); 707 SmallVector<int64_t, 4> newShape; 708 newShape.reserve(reassociation.size()); 709 710 // Use the fact that reassociation is valid to simplify the logic: only use 711 // each map's rank. 712 assert(isReassociationValid(reassociation) && "invalid reassociation"); 713 unsigned currentDim = 0; 714 for (AffineMap m : reassociation) { 715 unsigned dim = m.getNumResults(); 716 auto band = shape.slice(currentDim, dim); 717 int64_t size = 1; 718 if (llvm::is_contained(band, ShapedType::kDynamicSize)) 719 size = ShapedType::kDynamicSize; 720 else 721 for (unsigned d = 0; d < dim; ++d) 722 size *= shape[currentDim + d]; 723 newShape.push_back(size); 724 currentDim += dim; 725 } 726 727 return RankedTensorType::get(newShape, type.getElementType()); 728 } 729 730 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 731 ArrayRef<ReassociationIndices> reassociation, 732 ArrayRef<NamedAttribute> attrs) { 733 auto resultType = computeTensorReshapeCollapsedType( 734 src.getType().cast<RankedTensorType>(), 735 getSymbolLessAffineMaps( 736 convertReassociationIndicesToExprs(b.getContext(), reassociation))); 737 build(b, result, resultType, src, attrs); 738 result.addAttribute(getReassociationAttrName(), 739 getReassociationIndicesAttribute(b, reassociation)); 740 } 741 742 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 743 ArrayRef<ReassociationIndices> reassociation, 744 ArrayRef<NamedAttribute> attrs) { 745 auto resultType = computeTensorReshapeCollapsedType( 746 src.getType().cast<RankedTensorType>(), 747 getSymbolLessAffineMaps( 748 convertReassociationIndicesToExprs(b.getContext(), reassociation))); 749 build(b, result, resultType, src, attrs); 750 result.addAttribute(getReassociationAttrName(), 751 getReassociationIndicesAttribute(b, reassociation)); 752 } 753 754 template <typename TensorReshapeOp, bool isExpansion = std::is_same< 755 TensorReshapeOp, ExpandShapeOp>::value> 756 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, 757 RankedTensorType expandedType, 758 RankedTensorType collapsedType) { 759 if (failed( 760 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 761 return failure(); 762 763 auto maps = op.getReassociationMaps(); 764 RankedTensorType expectedType = 765 computeTensorReshapeCollapsedType(expandedType, maps); 766 if (collapsedType != expectedType) 767 return op.emitOpError("expected collapsed type to be ") 768 << expectedType << ", but got " << collapsedType; 769 return success(); 770 } 771 772 static LogicalResult verify(ExpandShapeOp op) { 773 return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); 774 } 775 776 static LogicalResult verify(CollapseShapeOp op) { 777 return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); 778 } 779 780 namespace { 781 /// Reshape of a splat constant can be replaced with a constant of the result 782 /// type. 783 template <typename TensorReshapeOp> 784 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> { 785 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 786 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 787 PatternRewriter &rewriter) const override { 788 DenseElementsAttr attr; 789 if (!matchPattern(reshapeOp.src(), m_Constant(&attr))) 790 return failure(); 791 if (!attr || !attr.isSplat()) 792 return failure(); 793 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( 794 reshapeOp.getResultType(), attr.getRawData(), true); 795 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr); 796 return success(); 797 } 798 }; 799 800 } // namespace 801 802 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 803 MLIRContext *context) { 804 results.add<CollapseReshapeOps<ExpandShapeOp>, 805 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>, 806 FoldReshapeWithConstant<ExpandShapeOp>>(context); 807 } 808 809 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 810 MLIRContext *context) { 811 results.add<CollapseReshapeOps<CollapseShapeOp>, 812 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 813 FoldReshapeWithConstant<CollapseShapeOp>>(context); 814 } 815 816 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 817 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 818 } 819 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 820 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // ExtractSliceOp 825 //===----------------------------------------------------------------------===// 826 827 /// An extract_slice op result type can be fully inferred from the source type 828 /// and the static representation of offsets, sizes and strides. Special 829 /// sentinels encode the dynamic case. 830 RankedTensorType 831 ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, 832 ArrayRef<int64_t> leadingStaticOffsets, 833 ArrayRef<int64_t> leadingStaticSizes, 834 ArrayRef<int64_t> leadingStaticStrides) { 835 // An extract_slice op may specify only a leading subset of offset/sizes/ 836 // strides in which case we complete with offset=0, sizes from memref type and 837 // strides=1. 838 unsigned rank = sourceRankedTensorType.getRank(); 839 assert(leadingStaticSizes.size() <= rank && 840 "unexpected leadingStaticSizes overflow"); 841 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 842 unsigned numTrailingSizes = rank - staticSizes.size(); 843 llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back( 844 numTrailingSizes)); 845 return RankedTensorType::get(staticSizes, 846 sourceRankedTensorType.getElementType()); 847 } 848 849 RankedTensorType 850 ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, 851 ArrayRef<OpFoldResult> leadingStaticOffsets, 852 ArrayRef<OpFoldResult> leadingStaticSizes, 853 ArrayRef<OpFoldResult> leadingStaticStrides) { 854 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 855 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 856 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 857 staticOffsets, ShapedType::kDynamicStrideOrOffset); 858 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 859 ShapedType::kDynamicSize); 860 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 861 staticStrides, ShapedType::kDynamicStrideOrOffset); 862 return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, 863 staticSizes, staticStrides); 864 } 865 866 /// An extract_slice op result type can be fully inferred from the source type 867 /// and the static representation of offsets, sizes and strides. Special 868 /// sentinels encode the dynamic case. 869 RankedTensorType ExtractSliceOp::inferRankReducedResultType( 870 unsigned resultRank, RankedTensorType sourceRankedTensorType, 871 ArrayRef<int64_t> leadingStaticOffsets, 872 ArrayRef<int64_t> leadingStaticSizes, 873 ArrayRef<int64_t> leadingStaticStrides) { 874 auto inferredType = 875 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 876 leadingStaticSizes, leadingStaticStrides) 877 .cast<RankedTensorType>(); 878 int rankDiff = inferredType.getRank() - resultRank; 879 if (rankDiff > 0) { 880 auto shape = inferredType.getShape(); 881 llvm::SmallDenseSet<unsigned> dimsToProject; 882 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 883 SmallVector<int64_t> projectedShape; 884 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 885 if (!dimsToProject.contains(pos)) 886 projectedShape.push_back(shape[pos]); 887 inferredType = 888 RankedTensorType::get(projectedShape, inferredType.getElementType()); 889 } 890 return inferredType; 891 } 892 893 RankedTensorType ExtractSliceOp::inferRankReducedResultType( 894 unsigned resultRank, RankedTensorType sourceRankedTensorType, 895 ArrayRef<OpFoldResult> leadingStaticOffsets, 896 ArrayRef<OpFoldResult> leadingStaticSizes, 897 ArrayRef<OpFoldResult> leadingStaticStrides) { 898 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 899 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 900 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 901 staticOffsets, ShapedType::kDynamicStrideOrOffset); 902 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 903 ShapedType::kDynamicSize); 904 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 905 staticStrides, ShapedType::kDynamicStrideOrOffset); 906 return ExtractSliceOp::inferRankReducedResultType( 907 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 908 staticStrides); 909 } 910 911 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom 912 /// result type. If the type passed is nullptr, it is inferred. 913 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, 914 RankedTensorType resultType, Value source, 915 ArrayRef<OpFoldResult> offsets, 916 ArrayRef<OpFoldResult> sizes, 917 ArrayRef<OpFoldResult> strides, 918 ArrayRef<NamedAttribute> attrs) { 919 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 920 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 921 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 922 923 ShapedType::kDynamicStrideOrOffset); 924 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 925 ShapedType::kDynamicSize); 926 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 927 928 ShapedType::kDynamicStrideOrOffset); 929 auto sourceRankedTensorType = source.getType().cast<RankedTensorType>(); 930 // Structuring implementation this way avoids duplication between builders. 931 if (!resultType) { 932 resultType = 933 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, 934 staticSizes, staticStrides) 935 .cast<RankedTensorType>(); 936 } 937 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 938 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 939 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 940 result.addAttributes(attrs); 941 } 942 943 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred 944 /// result type. 945 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, 946 ArrayRef<OpFoldResult> offsets, 947 ArrayRef<OpFoldResult> sizes, 948 ArrayRef<OpFoldResult> strides, 949 ArrayRef<NamedAttribute> attrs) { 950 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); 951 } 952 953 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the 954 /// type passed is nullptr, it is inferred. 955 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, 956 RankedTensorType resultType, Value source, 957 ValueRange offsets, ValueRange sizes, 958 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 959 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 960 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 961 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 962 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 963 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 964 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 965 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 966 } 967 968 /// Build an ExtractSliceOp with dynamic entries and inferred result type. 969 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, 970 ValueRange offsets, ValueRange sizes, 971 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 972 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); 973 } 974 975 template <typename OpTy> 976 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, 977 OpTy op, Type expectedType) { 978 auto memrefType = expectedType.cast<ShapedType>(); 979 switch (result) { 980 case SliceVerificationResult::Success: 981 return success(); 982 case SliceVerificationResult::RankTooLarge: 983 return op.emitError("expected rank to be smaller or equal to ") 984 << "the other rank. "; 985 case SliceVerificationResult::SizeMismatch: 986 return op.emitError("expected type to be ") 987 << expectedType << " or a rank-reduced version. (size mismatch) "; 988 case SliceVerificationResult::ElemTypeMismatch: 989 return op.emitError("expected element type to be ") 990 << memrefType.getElementType(); 991 default: 992 llvm_unreachable("unexpected extract_slice op verification result"); 993 } 994 } 995 996 /// Verifier for ExtractSliceOp. 997 static LogicalResult verify(ExtractSliceOp op) { 998 // Verify result type against inferred type. 999 auto expectedType = 1000 ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(), 1001 op.getMixedSizes(), op.getMixedStrides()); 1002 auto result = 1003 isRankReducedType(expectedType.cast<ShapedType>(), op.getType()); 1004 return produceSliceErrorMsg(result, op, expectedType); 1005 } 1006 1007 /// Infer the canonical type of the result of an extract_slice op. Returns a 1008 /// type with rank `resultRank` that is either the rank of the rank-reduced 1009 /// type, or the non-rank-reduced type. 1010 static RankedTensorType 1011 getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType, 1012 ArrayRef<OpFoldResult> mixedOffsets, 1013 ArrayRef<OpFoldResult> mixedSizes, 1014 ArrayRef<OpFoldResult> mixedStrides) { 1015 auto resultType = 1016 ExtractSliceOp::inferRankReducedResultType( 1017 resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides) 1018 .cast<RankedTensorType>(); 1019 if (resultType.getRank() != resultRank) { 1020 resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets, 1021 mixedSizes, mixedStrides) 1022 .cast<RankedTensorType>(); 1023 } 1024 return resultType; 1025 } 1026 1027 llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() { 1028 llvm::SmallDenseSet<unsigned> droppedDims; 1029 ArrayRef<int64_t> resultShape = getType().getShape(); 1030 SmallVector<OpFoldResult> mixedSizes = getMixedSizes(); 1031 unsigned shapePos = 0; 1032 for (auto size : enumerate(mixedSizes)) { 1033 Optional<int64_t> sizeVal = getConstantIntValue(size.value()); 1034 // If the size is not 1, or if the current matched dimension of the result 1035 // is the same static shape as the size value (which is 1), then the 1036 // dimension is preserved. 1037 if (!sizeVal || sizeVal.getValue() != 1 || 1038 (shapePos < resultShape.size() && resultShape[shapePos] == 1)) { 1039 shapePos++; 1040 continue; 1041 } 1042 droppedDims.insert(size.index()); 1043 } 1044 return droppedDims; 1045 } 1046 1047 LogicalResult ExtractSliceOp::reifyResultShapes( 1048 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 1049 reifiedReturnShapes.resize(1); 1050 reifiedReturnShapes[0].reserve(getType().getRank()); 1051 SmallVector<OpFoldResult> mixedSizes = getMixedSizes(); 1052 llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims(); 1053 Location loc = getLoc(); 1054 for (auto size : enumerate(mixedSizes)) { 1055 if (droppedDims.count(size.index())) 1056 continue; 1057 if (auto attr = size.value().dyn_cast<Attribute>()) { 1058 reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>( 1059 loc, attr.cast<IntegerAttr>().getInt())); 1060 continue; 1061 } 1062 reifiedReturnShapes[0].push_back(size.value().get<Value>()); 1063 } 1064 return success(); 1065 } 1066 1067 namespace { 1068 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments. 1069 /// This essentially pushes memref_cast past its consuming slice when 1070 /// `canFoldIntoConsumerOp` is true. 1071 /// 1072 /// Example: 1073 /// ``` 1074 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32> 1075 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to 1076 /// tensor<3x4xf32> 1077 /// ``` 1078 /// is rewritten into: 1079 /// ``` 1080 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to 1081 /// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32> 1082 /// ``` 1083 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> { 1084 public: 1085 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern; 1086 1087 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, 1088 PatternRewriter &rewriter) const override { 1089 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 1090 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) { 1091 return matchPattern(operand, matchConstantIndex()); 1092 })) 1093 return failure(); 1094 1095 auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>(); 1096 if (!castOp) 1097 return failure(); 1098 1099 if (!canFoldIntoConsumerOp(castOp)) 1100 return failure(); 1101 1102 /// Deduce the type of the result to use for the canonicalized operation. 1103 RankedTensorType resultType = getCanonicalSliceResultType( 1104 sliceOp.getType().getRank(), sliceOp.getSourceType(), 1105 sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), 1106 sliceOp.getMixedStrides()); 1107 Value newSlice = rewriter.create<ExtractSliceOp>( 1108 sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(), 1109 sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(), 1110 sliceOp.static_sizes(), sliceOp.static_strides()); 1111 rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(), 1112 newSlice); 1113 return success(); 1114 } 1115 }; 1116 } // namespace 1117 1118 /// Return the canonical type of the result of an extract_slice op. 1119 struct SliceReturnTypeCanonicalizer { 1120 RankedTensorType operator()(ExtractSliceOp op, 1121 ArrayRef<OpFoldResult> mixedOffsets, 1122 ArrayRef<OpFoldResult> mixedSizes, 1123 ArrayRef<OpFoldResult> mixedStrides) { 1124 return getCanonicalSliceResultType(op.getType().getRank(), 1125 op.getSourceType(), mixedOffsets, 1126 mixedSizes, mixedStrides); 1127 } 1128 }; 1129 1130 /// A canonicalizer wrapper to replace ExtractSliceOps. 1131 struct SliceCanonicalizer { 1132 void operator()(PatternRewriter &rewriter, ExtractSliceOp op, 1133 ExtractSliceOp newOp) { 1134 Value replacement = newOp.getResult(); 1135 if (replacement.getType() != op.getType()) 1136 replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), 1137 replacement); 1138 rewriter.replaceOp(op, replacement); 1139 } 1140 }; 1141 1142 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, 1143 MLIRContext *context) { 1144 results.add< 1145 OpWithOffsetSizesAndStridesConstantArgumentFolder< 1146 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>, 1147 ExtractSliceOpCastFolder>(context); 1148 } 1149 1150 // 1151 static LogicalResult 1152 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, 1153 ShapedType shapedType) { 1154 OpBuilder b(op.getContext()); 1155 for (OpFoldResult ofr : op.getMixedOffsets()) 1156 if (getConstantIntValue(ofr) != static_cast<int64_t>(0)) 1157 return failure(); 1158 // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip 1159 // is appropriate. 1160 auto shape = shapedType.getShape(); 1161 for (auto it : llvm::zip(op.getMixedSizes(), shape)) 1162 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it)) 1163 return failure(); 1164 for (OpFoldResult ofr : op.getMixedStrides()) 1165 if (getConstantIntValue(ofr) != static_cast<int64_t>(1)) 1166 return failure(); 1167 return success(); 1168 } 1169 1170 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, 1171 /// we can return the InsertSliceOp's source directly. 1172 // TODO: This only checks the immediate producer; extend to go up the 1173 // insert/extract chain if the slices are disjoint. 1174 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) { 1175 auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>(); 1176 1177 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; 1178 if (insertOp && insertOp.source().getType() == extractOp.getType() && 1179 insertOp.isSameAs(extractOp, isSame)) 1180 return insertOp.source(); 1181 1182 return {}; 1183 } 1184 1185 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) { 1186 if (getSourceType() == getType() && 1187 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) 1188 return this->source(); 1189 if (Value slice = foldExtractAfterInsertSlice(*this)) 1190 return slice; 1191 return OpFoldResult(); 1192 } 1193 1194 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp( 1195 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) { 1196 auto rankedTensorType = tensor.getType().cast<RankedTensorType>(); 1197 unsigned rank = rankedTensorType.getRank(); 1198 auto shape = rankedTensorType.getShape(); 1199 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0)); 1200 SmallVector<OpFoldResult> sizes; 1201 for (unsigned i = 0, e = rank; i < e; ++i) { 1202 OpFoldResult dim; 1203 if (rankedTensorType.isDynamicDim(i)) 1204 dim = b.createOrFold<tensor::DimOp>( 1205 loc, tensor, b.create<arith::ConstantIndexOp>(loc, i)); 1206 else 1207 dim = b.getIndexAttr(shape[i]); 1208 sizes.push_back(dim); 1209 } 1210 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1)); 1211 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor, 1212 offsets, sizes, strides); 1213 } 1214 1215 //===----------------------------------------------------------------------===// 1216 // InsertSliceOp 1217 //===----------------------------------------------------------------------===// 1218 1219 // Build a InsertSliceOp with mixed static and dynamic entries. 1220 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, 1221 Value dest, ArrayRef<OpFoldResult> offsets, 1222 ArrayRef<OpFoldResult> sizes, 1223 ArrayRef<OpFoldResult> strides, 1224 ArrayRef<NamedAttribute> attrs) { 1225 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1226 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1227 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1228 1229 ShapedType::kDynamicStrideOrOffset); 1230 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1231 ShapedType::kDynamicSize); 1232 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1233 1234 ShapedType::kDynamicStrideOrOffset); 1235 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes, 1236 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1237 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1238 result.addAttributes(attrs); 1239 } 1240 1241 // Build a InsertSliceOp with dynamic entries. 1242 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, 1243 Value dest, ValueRange offsets, ValueRange sizes, 1244 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 1245 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1246 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1247 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1248 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1249 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1250 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1251 build(b, result, source, dest, offsetValues, sizeValues, strideValues); 1252 } 1253 1254 /// Verifier for InsertSliceOp. 1255 static LogicalResult verify(InsertSliceOp op) { 1256 // insert_slice is the inverse of extract_slice, use the same type inference. 1257 auto expectedType = ExtractSliceOp::inferRankReducedResultType( 1258 op.getSourceType().getRank(), op.getType(), 1259 extractFromI64ArrayAttr(op.static_offsets()), 1260 extractFromI64ArrayAttr(op.static_sizes()), 1261 extractFromI64ArrayAttr(op.static_strides())); 1262 auto result = 1263 isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType()); 1264 return produceSliceErrorMsg(result, op, expectedType); 1265 } 1266 1267 /// If we have two consecutive InsertSliceOp writing to the same slice, we 1268 /// can mutate the second InsertSliceOp's destination to the first one's. 1269 /// 1270 /// Example: 1271 /// 1272 /// ```mlir 1273 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1] 1274 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1] 1275 /// ``` 1276 /// 1277 /// folds into: 1278 /// 1279 /// ```mlir 1280 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1] 1281 /// ``` 1282 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) { 1283 auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>(); 1284 1285 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; 1286 if (!prevInsertOp || 1287 prevInsertOp.source().getType() != insertOp.source().getType() || 1288 !prevInsertOp.isSameAs(insertOp, isSame)) 1289 return failure(); 1290 1291 insertOp.destMutable().assign(prevInsertOp.dest()); 1292 return success(); 1293 } 1294 1295 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) { 1296 if (getSourceType().hasStaticShape() && getType().hasStaticShape() && 1297 getSourceType() == getType() && 1298 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) 1299 return this->source(); 1300 if (succeeded(foldInsertAfterInsertSlice(*this))) 1301 return getResult(); 1302 return OpFoldResult(); 1303 } 1304 1305 LogicalResult InsertSliceOp::reifyResultShapes( 1306 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 1307 reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank())); 1308 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) { 1309 reifiedReturnShapes[0][dim] = 1310 builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim); 1311 } 1312 return success(); 1313 } 1314 1315 namespace { 1316 /// Pattern to rewrite a insert_slice op with constant arguments. 1317 class InsertSliceOpConstantArgumentFolder final 1318 : public OpRewritePattern<InsertSliceOp> { 1319 public: 1320 using OpRewritePattern<InsertSliceOp>::OpRewritePattern; 1321 1322 LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, 1323 PatternRewriter &rewriter) const override { 1324 // No constant operand, just return. 1325 if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) { 1326 return matchPattern(operand, matchConstantIndex()); 1327 })) 1328 return failure(); 1329 1330 // At least one of offsets/sizes/strides is a new constant. 1331 // Form the new list of operands and constant attributes from the 1332 // existing. 1333 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets()); 1334 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes()); 1335 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides()); 1336 canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset); 1337 canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); 1338 canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); 1339 1340 // Create the new op in canonical form. 1341 auto sourceType = ExtractSliceOp::inferRankReducedResultType( 1342 insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(), 1343 mixedOffsets, mixedSizes, mixedStrides); 1344 Value toInsert = insertSliceOp.source(); 1345 if (sourceType != insertSliceOp.getSourceType()) 1346 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(), 1347 sourceType, toInsert); 1348 rewriter.replaceOpWithNewOp<InsertSliceOp>( 1349 insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes, 1350 mixedStrides); 1351 return success(); 1352 } 1353 }; 1354 1355 /// Fold tensor_casts with insert_slice operations. If the source or destination 1356 /// tensor is a tensor_cast that removes static type information, the cast is 1357 /// folded into the insert_slice operation. E.g.: 1358 /// 1359 /// ```mlir 1360 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 1361 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ... 1362 /// ``` 1363 /// 1364 /// folds into: 1365 /// 1366 /// ```mlir 1367 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ... 1368 /// ``` 1369 /// 1370 /// Note: When folding a cast on the destination tensor, the result of the 1371 /// insert_slice operation is casted to ensure that the type of the result did 1372 /// not change. 1373 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> { 1374 using OpRewritePattern<InsertSliceOp>::OpRewritePattern; 1375 1376 LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, 1377 PatternRewriter &rewriter) const override { 1378 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) { 1379 return matchPattern(operand, matchConstantIndex()); 1380 })) 1381 return failure(); 1382 1383 auto getSourceOfCastOp = [](Value v) -> Optional<Value> { 1384 auto castOp = v.getDefiningOp<tensor::CastOp>(); 1385 if (!castOp || !canFoldIntoConsumerOp(castOp)) 1386 return llvm::None; 1387 return castOp.source(); 1388 }; 1389 Optional<Value> sourceCastSource = 1390 getSourceOfCastOp(insertSliceOp.source()); 1391 Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest()); 1392 if (!sourceCastSource && !destCastSource) 1393 return failure(); 1394 1395 Value replacement = rewriter.create<InsertSliceOp>( 1396 insertSliceOp.getLoc(), 1397 (sourceCastSource ? *sourceCastSource : insertSliceOp.source()), 1398 (destCastSource ? *destCastSource : insertSliceOp.dest()), 1399 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), 1400 insertSliceOp.getMixedStrides()); 1401 1402 if (replacement.getType() != insertSliceOp.getType()) { 1403 replacement = rewriter.create<tensor::CastOp>( 1404 insertSliceOp.getLoc(), insertSliceOp.getType(), replacement); 1405 } 1406 rewriter.replaceOp(insertSliceOp, replacement); 1407 return success(); 1408 } 1409 }; 1410 1411 /// If additional static type information can be deduced from a insert_slice's 1412 /// size operands, insert an explicit cast of the op's source operand. This 1413 /// enables other canonicalization patterns that are matching for tensor_cast 1414 /// ops such as `ForOpTensorCastFolder` in SCF. 1415 /// 1416 /// Example: 1417 /// 1418 /// ```mlir 1419 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1] 1420 /// : tensor<?x?xf32> into ... 1421 /// ``` 1422 /// 1423 /// folds into: 1424 /// 1425 /// ```mlir 1426 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32> 1427 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1] 1428 /// : tensor<64x64xf32> into ... 1429 /// ``` 1430 struct InsertSliceOpSourceCastInserter final 1431 : public OpRewritePattern<InsertSliceOp> { 1432 using OpRewritePattern<InsertSliceOp>::OpRewritePattern; 1433 1434 LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, 1435 PatternRewriter &rewriter) const override { 1436 RankedTensorType srcType = insertSliceOp.getSourceType(); 1437 if (srcType.getRank() != insertSliceOp.getType().getRank()) 1438 return failure(); 1439 SmallVector<int64_t> newSrcShape(srcType.getShape().begin(), 1440 srcType.getShape().end()); 1441 for (int64_t i = 0; i < srcType.getRank(); ++i) { 1442 if (Optional<int64_t> constInt = 1443 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) 1444 newSrcShape[i] = *constInt; 1445 } 1446 1447 RankedTensorType newSrcType = 1448 RankedTensorType::get(newSrcShape, srcType.getElementType()); 1449 if (srcType == newSrcType || 1450 !preservesStaticInformation(srcType, newSrcType) || 1451 !tensor::CastOp::areCastCompatible(srcType, newSrcType)) 1452 return failure(); 1453 1454 // newSrcType is: 1455 // 1) Different from srcType. 1456 // 2) "More static" than srcType. 1457 // 3) Cast-compatible with srcType. 1458 // Insert the cast. 1459 Value cast = rewriter.create<tensor::CastOp>( 1460 insertSliceOp.getLoc(), newSrcType, insertSliceOp.source()); 1461 rewriter.replaceOpWithNewOp<InsertSliceOp>( 1462 insertSliceOp, cast, insertSliceOp.dest(), 1463 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), 1464 insertSliceOp.getMixedStrides()); 1465 return success(); 1466 } 1467 }; 1468 } // namespace 1469 1470 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, 1471 MLIRContext *context) { 1472 results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder, 1473 InsertSliceOpSourceCastInserter>(context); 1474 } 1475 1476 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b, 1477 Location loc, 1478 Value tensor, 1479 Value dest) { 1480 auto rankedTensorType = dest.getType().cast<RankedTensorType>(); 1481 unsigned rank = rankedTensorType.getRank(); 1482 auto shape = rankedTensorType.getShape(); 1483 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0)); 1484 SmallVector<OpFoldResult> sizes; 1485 for (unsigned i = 0, e = rank; i < e; ++i) { 1486 OpFoldResult dim; 1487 if (rankedTensorType.isDynamicDim(i)) 1488 dim = b.createOrFold<tensor::DimOp>( 1489 loc, dest, b.create<arith::ConstantIndexOp>(loc, i)); 1490 else 1491 dim = b.getIndexAttr(shape[i]); 1492 sizes.push_back(dim); 1493 } 1494 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1)); 1495 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets, 1496 sizes, strides); 1497 } 1498 1499 //===----------------------------------------------------------------------===// 1500 // TableGen'd op method definitions 1501 //===----------------------------------------------------------------------===// 1502 1503 #define GET_OP_CLASSES 1504 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc" 1505