//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized to
/// preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source tensor type has more static information than the result type.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract_element");

  return success();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of
  // the elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
  return {};
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(InsertOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
  Attribute scalar = operands[0];
  Attribute dest = operands[1];
  if (scalar && dest)
    if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
      if (scalar == splatDest.getSplatValue())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};
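// For illustration, StaticTensorGenerate rewrites a generate op whose dynamic
// extent is defined by a constant. Assuming %c8 is `constant 8 : index`, the
// IR below is a sketch with arbitrary shapes, not output copied from a test:
//
//   %0 = tensor.generate %c8 {
//   ^bb0(%i: index):
//     <computation>
//     yield %elem : index
//   } : tensor<?xindex>
//
// becomes a fully static generate followed by a cast back to the original
// dynamic type:
//
//   %static = tensor.generate {
//   ^bb0(%i: index):
//     <computation>
//     yield %elem : index
//   } : tensor<8xindex>
//   %0 = tensor.cast %static : tensor<8xindex> to tensor<?xindex>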
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.source(), extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static int64_t GetNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

static LogicalResult verify(ReshapeOp op) {
  TensorType operandType = op.source().getType().cast<TensorType>();
  TensorType resultType = op.result().getType().cast<TensorType>();

  if (operandType.getElementType() != resultType.getElementType())
    return op.emitOpError("element types of source and destination tensor "
                          "types should be the same");

  int64_t shapeSize =
      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (GetNumElements(operandRankedType) !=
          GetNumElements(resultRankedType))
        return op.emitOpError("source and destination tensor should have the "
                              "same number of elements");
    }
    if (shapeSize == TensorType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

/// Helper function to dispatch an OpFoldResult into either the `dynamicVec`
/// if it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries
/// that come from an AttrSizedOperandSegments trait.
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                      SmallVectorImpl<Value> &dynamicVec,
                                      SmallVectorImpl<int64_t> &staticVec,
                                      int64_t sentinel) {
  if (auto v = ofr.dyn_cast<Value>()) {
    dynamicVec.push_back(v);
    staticVec.push_back(sentinel);
    return;
  }
  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
  staticVec.push_back(apInt.getSExtValue());
}

static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                       SmallVectorImpl<Value> &dynamicVec,
                                       SmallVectorImpl<int64_t> &staticVec,
                                       int64_t sentinel) {
  for (auto ofr : ofrs)
    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}
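// For illustration, assuming a mixed list {IntegerAttr 0, Value %d,
// IntegerAttr 4} and ShapedType::kDynamicSize as the sentinel, the dispatch
// helpers above would produce staticVec = {0, kDynamicSize, 4} and
// dynamicVec = {%d}: attribute entries stay on the static side, while each
// Value is collected on the dynamic side and stands in as the sentinel in
// the static vector.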
/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                     ArrayRef<int64_t> leadingStaticOffsets,
                                     ArrayRef<int64_t> leadingStaticSizes,
                                     ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, in which case we complete with offset=0, sizes from the source
  // tensor type and strides=1.
  unsigned rank = sourceRankedTensorType.getRank();
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  unsigned numTrailingSizes = rank - staticSizes.size();
  llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
                                      numTrailingSizes));
  return RankedTensorType::get(staticSizes,
                               sourceRankedTensorType.getElementType());
}

/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

Type ExtractSliceOp::inferResultType(
    RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}
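// For illustration, with a tensor<8x16x4xf32> source and static sizes
// [4, 4, kDynamicSize] (offsets and strides do not affect the inferred
// shape), the inferred result type is tensor<4x4x?xf32>: static sizes become
// result dimensions and each dynamic size maps to a dynamic dimension. The
// shapes here are chosen purely as an example.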
/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
Type ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<RankedTensorType>();
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

Type ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType =
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                        staticSizes, staticStrides)
            .cast<RankedTensorType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}
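// For illustration, calling the builder above with a tensor<8x16xf32> source,
// offsets {0, %o}, sizes {4, 8}, strides {1, 1} and a null result type infers
// tensor<4x8xf32> and produces IR along the lines of (a sketch only; %o is an
// arbitrary index value):
//
//   %1 = tensor.extract_slice %0[0, %o] [4, 8] [1, 1]
//        : tensor<8x16xf32> to tensor<4x8xf32>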
/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with dynamic entries and custom result type. If
/// the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

enum SliceVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
};

/// Checks whether `originalType` can be rank-reduced to
/// `candidateReducedType`. This function is a slight variant of the
/// `is subsequence` algorithm, where the non-matching dimensions must be 1.
static SliceVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;
  if (!originalType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;
  if (originalType.isa<RankedTensorType>() &&
      !candidateReducedType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  // We are done for the tensor case.
  if (originalType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;

  return SliceVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          OpTy op, Type expectedType,
                                          StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  }
  llvm_unreachable("unexpected extract_slice op verification result");
}

/// Verifier for ExtractSliceOp.
static LogicalResult verify(ExtractSliceOp op) {
  // Verify result type against inferred type.
  auto expectedType = ExtractSliceOp::inferResultType(
      op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));
  auto result = isRankReducedType(expectedType, op.getType());
  return produceSliceErrorMsg(result, op, expectedType);
}

/// Infer the canonical type of the result of an extract_slice op. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static RankedTensorType
getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
                            ArrayRef<OpFoldResult> mixedOffsets,
                            ArrayRef<OpFoldResult> mixedSizes,
                            ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      ExtractSliceOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<RankedTensorType>();
  if (resultType.getRank() != resultRank) {
    resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
                                                 mixedSizes, mixedStrides)
                     .cast<RankedTensorType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
///   tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///   tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant index, just return and let the
    // constant-argument folder kick in first.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the type of the result to use for the canonicalized operation.
    RankedTensorType resultType = getCanonicalSliceResultType(
        sliceOp.getType().getRank(), sliceOp.getSourceType(),
        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
        sliceOp.getMixedStrides());
    Value newSlice = rewriter.create<ExtractSliceOp>(
        sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
        sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
        sliceOp.static_sizes(), sliceOp.static_strides());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
                                                newSlice);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSliceResultType(op.getType().getRank(),
                                       op.getSourceType(), mixedOffsets,
                                       mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace ExtractSliceOps.
struct SliceCanonicalizer {
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }
};

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}

/// Returns success when the offsets, sizes and strides of `op` describe an
/// identity slice of `shapedType`: all offsets are 0, all strides are 1, and
/// the sizes match the shape of `shapedType`.
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  OpBuilder b(op.getContext());
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
      return failure();
  // Rank-reducing noops only need to inspect the leading dimensions:
  // llvm::zip is appropriate.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (!isEqualConstantIntOrValue(std::get<0>(it),
                                   b.getIndexAttr(std::get<1>(it))))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
      return failure();
  return success();
}

OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  return OpFoldResult();
}
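// For illustration, the fold above turns a no-op slice such as
//
//   %1 = tensor.extract_slice %0[0, 0] [4, 8] [1, 1]
//        : tensor<4x8xf32> to tensor<4x8xf32>
//
// directly into %0: the source and result types match, all offsets are 0,
// all strides are 1, and the sizes equal the source shape. The shapes are
// arbitrary example values.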
//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          ArrayRef<OpFoldResult> strides,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  return OpFoldResult();
}

namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertSliceOp> {
public:
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing ones.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
        mixedOffsets, mixedSizes, mixedStrides);
    return success();
  }
};
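// For illustration, assuming %c0 is a `constant 0 : index`, the pattern above
// rewrites
//
//   %r = tensor.insert_slice %s into %d[%c0, 1] [2, 2] [1, 1]
//        : tensor<2x2xf32> into tensor<4x4xf32>
//
// into the same insert_slice with the offset expressed statically as [0, 1],
// dropping the now-unused %c0 operand. The shapes are example values only.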
/// Fold tensor_casts with insert_slice operations.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return llvm::None;
      return castOp.source();
    };
    Optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.source());
    Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    Value replacement = rewriter.create<InsertSliceOp>(
        insertSliceOp.getLoc(),
        (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
        (destCastSource ? *destCastSource : insertSliceOp.dest()),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

    if (replacement.getType() != insertSliceOp.getType()) {
      replacement = rewriter.create<tensor::CastOp>(
          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
    }
    rewriter.replaceOp(insertSliceOp, replacement);
    return success();
  }
};
} // namespace

void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
      context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"