1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/BlockAndValueMapping.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 using namespace mlir;
21 using namespace mlir::tensor;
22 
23 /// Materialize a single constant operation from a given attribute value with
24 /// the desired resultant type.
25 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
26                                               Attribute value, Type type,
27                                               Location loc) {
28   if (arith::ConstantOp::isBuildableWith(value, type))
29     return builder.create<arith::ConstantOp>(loc, value, type);
30   if (ConstantOp::isBuildableWith(value, type))
31     return builder.create<ConstantOp>(loc, value, type);
32   return nullptr;
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // CastOp
37 //===----------------------------------------------------------------------===//
38 
39 /// Returns true if `target` is a ranked tensor type that preserves static
40 /// information available in the `source` ranked tensor type.
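///
/// For illustration (shapes chosen arbitrarily, element types equal):
///   preservesStaticInformation(tensor<?x8xf32>, tensor<4x8xf32>) is true,
///   preservesStaticInformation(tensor<4x8xf32>, tensor<?x8xf32>) is false,
///   since in the second case the static extent 4 would be lost.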
41 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
42   auto sourceType = source.dyn_cast<RankedTensorType>();
43   auto targetType = target.dyn_cast<RankedTensorType>();
44 
45   // Requires RankedTensorType.
46   if (!sourceType || !targetType)
47     return false;
48 
  // Requires same element type.
50   if (sourceType.getElementType() != targetType.getElementType())
51     return false;
52 
53   // Requires same rank.
54   if (sourceType.getRank() != targetType.getRank())
55     return false;
56 
  // A dimension that is static in `source` must not become dynamic in
  // `target`, otherwise static information is lost.
58   for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
59     if (!ShapedType::isDynamic(std::get<0>(t)) &&
60         ShapedType::isDynamic(std::get<1>(t)))
61       return false;
62   }
63 
64   return true;
65 }
66 
67 /// Determines whether tensor::CastOp casts to a more dynamic version of the
68 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
69 /// implement canonicalization patterns for ops in different dialects that may
70 /// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized
/// to preserve the type compatibility of their uses.
73 ///
74 /// Returns true when all conditions are met:
75 /// 1. source and result are ranked tensors with same element type and rank.
/// 2. the source tensor type has more static information than the result.
77 ///
78 /// Example:
79 /// ```mlir
80 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
81 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
82 /// ```
83 ///
84 /// folds into:
85 ///
86 /// ```mlir
87 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
88 /// ```
89 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
90   if (!castOp)
91     return false;
92 
93   // Can fold if the source of cast has at least as much static information as
  // its result.
95   return preservesStaticInformation(castOp.getType(),
96                                     castOp.source().getType());
97 }
98 
99 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
100 /// that can be folded.
101 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
102   bool folded = false;
103   for (OpOperand &operand : op->getOpOperands()) {
104     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
105     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
106       operand.set(castOp.getOperand());
107       folded = true;
108     }
109   }
110   return success(folded);
111 }
112 
113 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
114   if (inputs.size() != 1 || outputs.size() != 1)
115     return false;
116   Type a = inputs.front(), b = outputs.front();
117   auto aT = a.dyn_cast<TensorType>();
118   auto bT = b.dyn_cast<TensorType>();
119   if (!aT || !bT)
120     return false;
121 
122   if (aT.getElementType() != bT.getElementType())
123     return false;
124 
125   return succeeded(verifyCompatibleShape(aT, bT));
126 }
127 
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match. Returns a null type
/// if the shapes cannot be joined (e.g., ranks differ or static sizes
/// conflict).
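///
/// A sketch of the join behavior (shapes illustrative):
///   join(tensor<?x8xf32>, tensor<4x?xf32>)  == tensor<4x8xf32>
///   join(tensor<*xf32>,   tensor<4x8xf32>)  == tensor<4x8xf32>
///   join(tensor<4x8xf32>, tensor<5x8xf32>)  == null (conflicting static dims)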
130 static TensorType joinShapes(TensorType one, TensorType two) {
131   assert(one.getElementType() == two.getElementType());
132 
133   if (!one.hasRank())
134     return two;
135   if (!two.hasRank())
136     return one;
137 
138   int64_t rank = one.getRank();
139   if (rank != two.getRank())
140     return {};
141 
142   SmallVector<int64_t, 4> join;
143   join.reserve(rank);
144   for (int64_t i = 0; i < rank; ++i) {
145     if (one.isDynamicDim(i)) {
146       join.push_back(two.getDimSize(i));
147       continue;
148     }
149     if (two.isDynamicDim(i)) {
150       join.push_back(one.getDimSize(i));
151       continue;
152     }
153     if (one.getDimSize(i) != two.getDimSize(i))
154       return {};
155     join.push_back(one.getDimSize(i));
156   }
157   return RankedTensorType::get(join, one.getElementType());
158 }
159 
160 namespace {
161 
162 /// Replaces chains of two tensor.cast operations by a single tensor.cast
163 /// operation if doing so does not remove runtime constraints.
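///
/// For example (shapes illustrative):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x8xf32>
///   %2 = tensor.cast %1 : tensor<?x8xf32> to tensor<?x?xf32>
/// ```
///
/// can be replaced by a single cast from tensor<4x8xf32> to tensor<?x?xf32>,
/// whereas a chain tensor<?x8xf32> -> tensor<4x8xf32> -> tensor<?x?xf32> must
/// keep its intermediate cast, because dropping it would drop a runtime check.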
164 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
165   using OpRewritePattern<CastOp>::OpRewritePattern;
166 
167   LogicalResult matchAndRewrite(CastOp tensorCast,
168                                 PatternRewriter &rewriter) const final {
169     auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
170 
171     if (!tensorCastOperand)
172       return failure();
173 
174     auto sourceType =
175         tensorCastOperand.getOperand().getType().cast<TensorType>();
176     auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
177     auto resultType = tensorCast.getType().cast<TensorType>();
178 
179     // We can remove the intermediate cast if joining all three produces the
180     // same result as just joining the source and result shapes.
181     auto firstJoin =
182         joinShapes(joinShapes(sourceType, intermediateType), resultType);
183 
184     // The join might not exist if the cast sequence would fail at runtime.
185     if (!firstJoin)
186       return failure();
187 
    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
191     auto newJoin = joinShapes(sourceType, resultType);
192     if (firstJoin != newJoin)
193       return failure();
194 
195     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
196                                         tensorCastOperand.getOperand());
197     return success();
198   }
199 };
200 
201 } // namespace
202 
203 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
204                                          MLIRContext *context) {
205   results.add<ChainedTensorCast>(context);
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // DimOp
210 //===----------------------------------------------------------------------===//
211 
212 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
213                   int64_t index) {
214   auto loc = result.location;
215   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
216   build(builder, result, source, indexValue);
217 }
218 
219 Optional<int64_t> DimOp::getConstantIndex() {
220   if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
221     return constantOp.getValue().cast<IntegerAttr>().getInt();
222   return {};
223 }
224 
225 static LogicalResult verify(DimOp op) {
226   // Assume unknown index to be in range.
227   Optional<int64_t> index = op.getConstantIndex();
228   if (!index.hasValue())
229     return success();
230 
231   // Check that constant index is not knowingly out of range.
232   auto type = op.source().getType();
233   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
234     if (index.getValue() >= tensorType.getRank())
235       return op.emitOpError("index is out of range");
236   } else if (type.isa<UnrankedTensorType>()) {
237     // Assume index to be in range.
238   } else {
239     llvm_unreachable("expected operand with tensor type");
240   }
241   return success();
242 }
243 
244 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
245   // All forms of folding require a known index.
246   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
247   if (!index)
248     return {};
249 
250   // Folding for unranked types (UnrankedTensorType) is not supported.
251   auto tensorType = source().getType().dyn_cast<RankedTensorType>();
252   if (!tensorType)
253     return {};
254 
255   // Fold if the shape extent along the given index is known.
256   if (!tensorType.isDynamicDim(index.getInt())) {
257     Builder builder(getContext());
258     return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
259   }
260 
261   Operation *definingOp = source().getDefiningOp();
262 
263   // Fold dim to the operand of tensor.generate.
264   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
265     auto resultType =
266         fromElements.getResult().getType().cast<RankedTensorType>();
267     // The case where the type encodes the size of the dimension is handled
268     // above.
269     assert(resultType.getShape()[index.getInt()] ==
270            RankedTensorType::kDynamicSize);
271 
272     // Find the operand of the fromElements that corresponds to this index.
273     auto dynExtents = fromElements.dynamicExtents().begin();
274     for (auto dim : resultType.getShape().take_front(index.getInt()))
275       if (dim == RankedTensorType::kDynamicSize)
276         dynExtents++;
277 
278     return Value{*dynExtents};
279   }
280 
281   // The size at the given index is now known to be a dynamic size.
282   unsigned unsignedIndex = index.getValue().getZExtValue();
283 
284   if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
285     // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
286     // `resolve-shaped-type-result-dims` pass.
287     if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
288         sliceOp.isDynamicSize(unsignedIndex)) {
289       return {sliceOp.getDynamicSize(unsignedIndex)};
290     }
291   }
292 
293   // dim(cast) -> dim
294   if (succeeded(foldTensorCast(*this)))
295     return getResult();
296 
297   return {};
298 }
299 
300 namespace {
301 /// Fold dim of a cast into the dim of the source of the tensor cast.
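///
/// A sketch of the rewrite (types illustrative):
///
/// ```mlir
///   %0 = tensor.cast %arg : tensor<4x?xf32> to tensor<?x?xf32>
///   %d = tensor.dim %0, %c1 : tensor<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %d = tensor.dim %arg, %c1 : tensor<4x?xf32>
/// ```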
302 struct DimOfCastOp : public OpRewritePattern<DimOp> {
303   using OpRewritePattern<DimOp>::OpRewritePattern;
304 
305   LogicalResult matchAndRewrite(DimOp dimOp,
306                                 PatternRewriter &rewriter) const override {
307     auto castOp = dimOp.source().getDefiningOp<CastOp>();
308     if (!castOp)
309       return failure();
310     Value newSource = castOp.getOperand();
311     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
312     return success();
313   }
314 };
315 } // end anonymous namespace.
316 
317 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
318                                         MLIRContext *context) {
319   results.add<DimOfCastOp>(context);
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // ExtractOp
324 //===----------------------------------------------------------------------===//
325 
326 static LogicalResult verify(ExtractOp op) {
327   // Verify the # indices match if we have a ranked type.
328   if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
329     if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");
331 
332   return success();
333 }
334 
335 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
336   // The tensor operand must be a known constant.
337   Attribute tensor = operands.front();
338   if (!tensor)
339     return {};
340   // If this is a splat elements attribute, simply return the value. All of the
341   // elements of a splat attribute are the same.
342   if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
343     return splatTensor.getSplatValue<Attribute>();
344 
345   // Otherwise, collect the constant indices into the tensor.
346   SmallVector<uint64_t, 8> indices;
347   for (Attribute indice : llvm::drop_begin(operands, 1)) {
348     if (!indice || !indice.isa<IntegerAttr>())
349       return {};
350     indices.push_back(indice.cast<IntegerAttr>().getInt());
351   }
352 
353   // If this is an elements attribute, query the value at the given indices.
354   auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
355   if (elementsAttr && elementsAttr.isValidIndex(indices))
356     return elementsAttr.getValues<Attribute>()[indices];
357   return {};
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // FromElementsOp
362 //===----------------------------------------------------------------------===//
363 
364 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
365                            Type elementType, ValueRange elements) {
366   Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
367                                         elementType);
368   result.addOperands(elements);
369   result.addTypes(resultTy);
370 }
371 
372 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
373                            ValueRange elements) {
374   assert(!elements.empty() && "expected at least one element");
375   build(builder, result, elements.front().getType(), elements);
376 }
377 
378 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
379   if (!llvm::is_contained(operands, nullptr))
380     return DenseElementsAttr::get(getType(), operands);
381   return {};
382 }
383 
384 namespace {
385 
386 // Canonicalizes the pattern of the form
387 //
388 // %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
389 // %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
390 //
391 // to just %element.
392 struct ExtractElementFromTensorFromElements
393     : public OpRewritePattern<tensor::ExtractOp> {
394   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
395 
396   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
397                                 PatternRewriter &rewriter) const final {
398     if (extract.indices().size() != 1)
399       return failure();
400 
401     auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
402     if (tensorFromElements == nullptr)
403       return failure();
404 
405     APInt index;
406     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
407       return failure();
408     // Prevent out of bounds accesses. This can happen in invalid code that will
409     // never execute.
410     if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
411         index.getSExtValue() < 0)
412       return failure();
413     rewriter.replaceOp(extract,
414                        tensorFromElements.getOperand(index.getZExtValue()));
415     return success();
416   }
417 };
418 
419 } // namespace
420 
421 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
422                                                  MLIRContext *context) {
423   results.add<ExtractElementFromTensorFromElements>(context);
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // InsertOp
428 //===----------------------------------------------------------------------===//
429 
430 static LogicalResult verify(InsertOp op) {
431   // Verify the # indices match if we have a ranked type.
432   if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
433     if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
434       return op.emitOpError("incorrect number of indices");
435   return success();
436 }
437 
438 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
439   Attribute scalar = operands[0];
440   Attribute dest = operands[1];
441   if (scalar && dest)
442     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
443       if (scalar == splatDest.getSplatValue<Attribute>())
444         return dest;
445   return {};
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // GenerateOp
450 //===----------------------------------------------------------------------===//
451 
452 static LogicalResult verify(GenerateOp op) {
453   // Ensure that the tensor type has as many dynamic dimensions as are specified
454   // by the operands.
455   RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
456   if (op.getNumOperands() != resultTy.getNumDynamicDims())
457     return op.emitError("must have as many index operands as dynamic extents "
458                         "in the result type");
459 
460   // Ensure that region arguments span the index space.
461   if (!llvm::all_of(op.body().getArgumentTypes(),
462                     [](Type ty) { return ty.isIndex(); }))
463     return op.emitError("all body arguments must be index");
464   if (op.body().getNumArguments() != resultTy.getRank())
465     return op.emitError("must have one body argument per input dimension");
466 
467   // Ensure that the region yields an element of the right type.
468   auto yieldOp =
469       llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
470   if (yieldOp.value().getType() != resultTy.getElementType())
471     return op.emitOpError(
472         "body must be terminated with a `yield` operation of the tensor "
473         "element type");
474 
475   return success();
476 }
477 
478 void GenerateOp::build(
479     OpBuilder &b, OperationState &result, Type resultTy,
480     ValueRange dynamicExtents,
481     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
482   build(b, result, resultTy, dynamicExtents);
483 
484   // Build and populate body.
485   OpBuilder::InsertionGuard guard(b);
486   Region *bodyRegion = result.regions.front().get();
487   auto rank = resultTy.cast<RankedTensorType>().getRank();
488   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
489   Block *bodyBlock =
490       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
491   bodyBuilder(b, result.location, bodyBlock->getArguments());
492 }
493 
494 namespace {
495 
/// Canonicalizes tensor.generate operations with constant extent operands
/// into an equivalent operation with those extents reflected in the result
/// type. A tensor.cast back to the original type is inserted so that the
/// resulting IR remains well-typed.
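///
/// For example (extent and shapes illustrative):
///
/// ```mlir
///   %size = arith.constant 32 : index
///   %0 = tensor.generate %size {
///     ...
///   } : tensor<16x?xindex>
/// ```
///
/// is rewritten to a tensor.generate producing tensor<16x32xindex>, followed
/// by a tensor.cast back to tensor<16x?xindex> for the existing uses.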
500 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
501   using OpRewritePattern<GenerateOp>::OpRewritePattern;
502 
503   LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
504                                 PatternRewriter &rewriter) const final {
505     auto resultType =
506         tensorFromElements.getResult().getType().cast<RankedTensorType>();
507 
508     if (resultType.hasStaticShape())
509       return failure();
510 
511     SmallVector<Value, 4> newOperands;
512     SmallVector<int64_t, 4> newShape;
513     auto operandsIt = tensorFromElements.dynamicExtents().begin();
514 
515     for (int64_t dim : resultType.getShape()) {
516       if (dim != RankedTensorType::kDynamicSize) {
517         newShape.push_back(dim);
518         continue;
519       }
520       APInt index;
521       if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
522         newShape.push_back(RankedTensorType::kDynamicSize);
523         newOperands.push_back(*operandsIt++);
524         continue;
525       }
526       newShape.push_back(index.getSExtValue());
527       operandsIt++;
528     }
529 
530     if (newOperands.size() == tensorFromElements.dynamicExtents().size())
531       return failure();
532 
533     auto loc = tensorFromElements.getLoc();
534     auto newOp = rewriter.create<GenerateOp>(
535         loc, RankedTensorType::get(newShape, resultType.getElementType()),
536         newOperands);
537     rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
538                                 newOp.body().begin());
539     rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
540                                                 newOp);
541     return success();
542   }
543 };
544 
545 /// Canonicalizes the pattern of the form
546 ///
547 /// %tensor = tensor.generate %x {
548 ///   ^bb0(%arg0: index):  // no predecessors
549 ///   <computation>
///   tensor.yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
553 ///
554 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
555 /// tensor.generate operation has no side-effects.
556 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
557   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
558 
559   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
560                                 PatternRewriter &rewriter) const final {
561     auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
562     if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
563       return failure();
564 
565     BlockAndValueMapping mapping;
566     Block *body = tensorFromElements.getBody();
567     mapping.map(body->getArguments(), extract.indices());
568     for (auto &op : body->without_terminator())
569       rewriter.clone(op, mapping);
570 
571     auto yield = cast<YieldOp>(body->getTerminator());
572 
573     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
574     return success();
575   }
576 };
577 
578 /// Canonicalizes the pattern of the form
579 ///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
581 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
582 ///
583 /// to
584 ///
585 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
586 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
587   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
588 
589   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
590                                 PatternRewriter &rewriter) const final {
591     auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
592     if (!tensorCast)
593       return failure();
594 
595     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
596                                                    extract.indices());
597     return success();
598   }
599 };
600 
601 } // namespace
602 
603 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
604                                              MLIRContext *context) {
605   // TODO: Move extract patterns to tensor::ExtractOp.
606   results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
607               StaticTensorGenerate>(context);
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // ReshapeOp
612 //===----------------------------------------------------------------------===//
613 
static int64_t getNumElements(ShapedType type) {
615   int64_t numElements = 1;
616   for (auto dim : type.getShape())
617     numElements *= dim;
618   return numElements;
619 }
620 
621 static LogicalResult verify(ReshapeOp op) {
622   TensorType operandType = op.source().getType().cast<TensorType>();
623   TensorType resultType = op.result().getType().cast<TensorType>();
624 
625   if (operandType.getElementType() != resultType.getElementType())
626     return op.emitOpError("element types of source and destination tensor "
627                           "types should be the same");
628 
629   int64_t shapeSize =
630       op.shape().getType().cast<RankedTensorType>().getDimSize(0);
631   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
632   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
633 
634   if (resultRankedType) {
635     if (operandRankedType && resultRankedType.hasStaticShape() &&
636         operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
638         return op.emitOpError("source and destination tensor should have the "
639                               "same number of elements");
640     }
641     if (shapeSize == TensorType::kDynamicSize)
642       return op.emitOpError("cannot use shape operand with dynamic length to "
643                             "reshape to statically-ranked tensor type");
644     if (shapeSize != resultRankedType.getRank())
645       return op.emitOpError(
646           "length of shape operand differs from the result's tensor rank");
647   }
648   return success();
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // ExtractSliceOp
653 //===----------------------------------------------------------------------===//
654 
655 /// An extract_slice op result type can be fully inferred from the source type
656 /// and the static representation of offsets, sizes and strides. Special
657 /// sentinels encode the dynamic case.
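///
/// For illustration (types chosen arbitrarily): with a source type of
/// tensor<8x16x4xf32> and leading static sizes {2, ShapedType::kDynamicSize},
/// the inferred result type is tensor<2x?x4xf32>; the missing trailing size
/// is completed from the source shape.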
658 Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
659                                      ArrayRef<int64_t> leadingStaticOffsets,
660                                      ArrayRef<int64_t> leadingStaticSizes,
661                                      ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offset/sizes/
  // strides in which case we complete with offset=0, sizes from the source
  // tensor type, and strides=1.
665   unsigned rank = sourceRankedTensorType.getRank();
666   assert(leadingStaticSizes.size() <= rank &&
667          "unexpected leadingStaticSizes overflow");
668   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
669   unsigned numTrailingSizes = rank - staticSizes.size();
670   llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
671                                       numTrailingSizes));
672   return RankedTensorType::get(staticSizes,
673                                sourceRankedTensorType.getElementType());
674 }
675 
676 Type ExtractSliceOp::inferResultType(
677     RankedTensorType sourceRankedTensorType,
678     ArrayRef<OpFoldResult> leadingStaticOffsets,
679     ArrayRef<OpFoldResult> leadingStaticSizes,
680     ArrayRef<OpFoldResult> leadingStaticStrides) {
681   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
682   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
683   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
684                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
685   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
686                              ShapedType::kDynamicSize);
687   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
688                              staticStrides, ShapedType::kDynamicStrideOrOffset);
689   return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
690                                          staticSizes, staticStrides);
691 }
692 
693 /// An extract_slice op result type can be fully inferred from the source type
694 /// and the static representation of offsets, sizes and strides. Special
695 /// sentinels encode the dynamic case.
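///
/// For illustration, if the non-rank-reduced inference yields
/// tensor<1x3x1x4xf32> and `resultRank` is 2, the two unit dimensions are
/// projected away and the returned type is tensor<3x4xf32>.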
696 Type ExtractSliceOp::inferRankReducedResultType(
697     unsigned resultRank, RankedTensorType sourceRankedTensorType,
698     ArrayRef<int64_t> leadingStaticOffsets,
699     ArrayRef<int64_t> leadingStaticSizes,
700     ArrayRef<int64_t> leadingStaticStrides) {
701   auto inferredType =
702       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
703                       leadingStaticSizes, leadingStaticStrides)
704           .cast<RankedTensorType>();
705   int rankDiff = inferredType.getRank() - resultRank;
706   if (rankDiff > 0) {
707     auto shape = inferredType.getShape();
708     llvm::SmallDenseSet<unsigned> dimsToProject;
709     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
710     SmallVector<int64_t> projectedShape;
711     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
712       if (!dimsToProject.contains(pos))
713         projectedShape.push_back(shape[pos]);
714     inferredType =
715         RankedTensorType::get(projectedShape, inferredType.getElementType());
716   }
717   return inferredType;
718 }
719 
720 Type ExtractSliceOp::inferRankReducedResultType(
721     unsigned resultRank, RankedTensorType sourceRankedTensorType,
722     ArrayRef<OpFoldResult> leadingStaticOffsets,
723     ArrayRef<OpFoldResult> leadingStaticSizes,
724     ArrayRef<OpFoldResult> leadingStaticStrides) {
725   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
726   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
727   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
728                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
729   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
730                              ShapedType::kDynamicSize);
731   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
732                              staticStrides, ShapedType::kDynamicStrideOrOffset);
733   return ExtractSliceOp::inferRankReducedResultType(
734       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
735       staticStrides);
736 }
737 
738 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
739 /// result type. If the type passed is nullptr, it is inferred.
740 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
741                            RankedTensorType resultType, Value source,
742                            ArrayRef<OpFoldResult> offsets,
743                            ArrayRef<OpFoldResult> sizes,
744                            ArrayRef<OpFoldResult> strides,
745                            ArrayRef<NamedAttribute> attrs) {
746   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
747   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
748   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
749                              ShapedType::kDynamicStrideOrOffset);
750   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
751                              ShapedType::kDynamicSize);
752   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
753                              ShapedType::kDynamicStrideOrOffset);
754   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
755   // Structuring implementation this way avoids duplication between builders.
756   if (!resultType) {
757     resultType =
758         ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
759                                         staticSizes, staticStrides)
760             .cast<RankedTensorType>();
761   }
762   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
763         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
764         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
765   result.addAttributes(attrs);
766 }
767 
768 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
769 /// result type.
770 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
771                            ArrayRef<OpFoldResult> offsets,
772                            ArrayRef<OpFoldResult> sizes,
773                            ArrayRef<OpFoldResult> strides,
774                            ArrayRef<NamedAttribute> attrs) {
775   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
776 }
777 
778 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
779 /// type passed is nullptr, it is inferred.
780 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
781                            RankedTensorType resultType, Value source,
782                            ValueRange offsets, ValueRange sizes,
783                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
784   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
785       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
786   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
787       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
788   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
789       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
791 }
792 
793 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
794 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
795                            ValueRange offsets, ValueRange sizes,
796                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
797   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
798 }
799 
800 enum SliceVerificationResult {
801   Success,
802   RankTooLarge,
803   SizeMismatch,
804   ElemTypeMismatch,
805 };
806 
/// Checks whether `originalType` can be rank-reduced to
/// `candidateReducedType`. This is a slight variant of the `is subsequence`
/// algorithm, where every non-matching dimension must be statically 1.
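///
/// For example (shapes illustrative), tensor<1x4x1x8xf32> can be rank-reduced
/// to tensor<4x1x8xf32> or tensor<4x8xf32>, but not to tensor<4x4xf32>, since
/// the dropped dimensions must all be static 1 and the remaining ones must
/// match in order.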
810 static SliceVerificationResult
811 isRankReducedType(Type originalType, Type candidateReducedType,
812                   std::string *errMsg = nullptr) {
813   if (originalType == candidateReducedType)
814     return SliceVerificationResult::Success;
815   if (!originalType.isa<RankedTensorType>())
816     return SliceVerificationResult::Success;
817   if (originalType.isa<RankedTensorType>() &&
818       !candidateReducedType.isa<RankedTensorType>())
819     return SliceVerificationResult::Success;
820 
821   ShapedType originalShapedType = originalType.cast<ShapedType>();
822   ShapedType candidateReducedShapedType =
823       candidateReducedType.cast<ShapedType>();
824 
825   // Rank and size logic is valid for all ShapedTypes.
826   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
827   ArrayRef<int64_t> candidateReducedShape =
828       candidateReducedShapedType.getShape();
829   unsigned originalRank = originalShape.size(),
830            candidateReducedRank = candidateReducedShape.size();
831   if (candidateReducedRank > originalRank)
832     return SliceVerificationResult::RankTooLarge;
833 
834   auto optionalUnusedDimsMask =
835       computeRankReductionMask(originalShape, candidateReducedShape);
836 
  // Sizes cannot be matched if no rank-reduction mask could be computed.
838   if (!optionalUnusedDimsMask.hasValue())
839     return SliceVerificationResult::SizeMismatch;
840 
841   if (originalShapedType.getElementType() !=
842       candidateReducedShapedType.getElementType())
843     return SliceVerificationResult::ElemTypeMismatch;
844 
  return SliceVerificationResult::Success;
850 }
851 
852 template <typename OpTy>
853 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
854                                           OpTy op, Type expectedType,
855                                           StringRef errMsg = "") {
  auto shapedType = expectedType.cast<ShapedType>();
857   switch (result) {
858   case SliceVerificationResult::Success:
859     return success();
860   case SliceVerificationResult::RankTooLarge:
861     return op.emitError("expected result rank to be smaller or equal to ")
862            << "the source rank. " << errMsg;
863   case SliceVerificationResult::SizeMismatch:
864     return op.emitError("expected result type to be ")
865            << expectedType
866            << " or a rank-reduced version. (mismatch of result sizes) "
867            << errMsg;
868   case SliceVerificationResult::ElemTypeMismatch:
869     return op.emitError("expected result element type to be ")
           << shapedType.getElementType() << errMsg;
871   }
872   llvm_unreachable("unexpected extract_slice op verification result");
873 }
874 
875 /// Verifier for ExtractSliceOp.
876 static LogicalResult verify(ExtractSliceOp op) {
877   // Verify result type against inferred type.
878   auto expectedType = ExtractSliceOp::inferResultType(
879       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
880       extractFromI64ArrayAttr(op.static_sizes()),
881       extractFromI64ArrayAttr(op.static_strides()));
882   auto result = isRankReducedType(expectedType, op.getType());
883   return produceSliceErrorMsg(result, op, expectedType);
884 }
885 
/// Infer the canonical type of the result of an extract_slice op. Returns the
/// rank-reduced type when it has exactly rank `resultRank`, and the
/// non-rank-reduced inferred type otherwise.
889 static RankedTensorType
890 getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
891                             ArrayRef<OpFoldResult> mixedOffsets,
892                             ArrayRef<OpFoldResult> mixedSizes,
893                             ArrayRef<OpFoldResult> mixedStrides) {
894   auto resultType =
895       ExtractSliceOp::inferRankReducedResultType(
896           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
897           .cast<RankedTensorType>();
898   if (resultType.getRank() != resultRank) {
899     resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
900                                                  mixedSizes, mixedStrides)
901                      .cast<RankedTensorType>();
902   }
903   return resultType;
904 }
905 
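/// Return the set of slice dimensions (indices into the op's mixed sizes)
/// that are dropped from the result type by rank reduction: their size is the
/// constant 1 and they have no corresponding unit dimension in the result
/// shape.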
906 llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
907   llvm::SmallDenseSet<unsigned> droppedDims;
908   ArrayRef<int64_t> resultShape = getType().getShape();
909   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
910   unsigned shapePos = 0;
911   for (auto size : enumerate(mixedSizes)) {
912     Optional<int64_t> sizeVal = getConstantIntValue(size.value());
    // The dimension is preserved if the slice size is not statically 1, or if
    // the currently matched result dimension is itself statically 1 (i.e. the
    // unit dimension was kept in the result type).
916     if (!sizeVal || sizeVal.getValue() != 1 ||
917         (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
918       shapePos++;
919       continue;
920     }
921     droppedDims.insert(size.index());
922   }
923   return droppedDims;
924 }
925 
926 LogicalResult ExtractSliceOp::reifyResultShapes(
927     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
928   reifiedReturnShapes.resize(1);
929   reifiedReturnShapes[0].reserve(getType().getRank());
930   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
931   llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
932   Location loc = getLoc();
933   for (auto size : enumerate(mixedSizes)) {
934     if (droppedDims.count(size.index()))
935       continue;
936     if (auto attr = size.value().dyn_cast<Attribute>()) {
937       reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
938           loc, attr.cast<IntegerAttr>().getInt()));
939       continue;
940     }
941     reifiedReturnShapes[0].push_back(size.value().get<Value>());
942   }
943   return success();
944 }
945 
946 namespace {
947 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
950 ///
951 /// Example:
952 /// ```
953 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
954 ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
955 ///   tensor<3x4xf32>
956 /// ```
957 /// is rewritten into:
958 /// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///   tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
961 /// ```
962 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
963 public:
964   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
965 
966   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
967                                 PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant-argument folder
    // kick in first.
969     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
970           return matchPattern(operand, matchConstantIndex());
971         }))
972       return failure();
973 
974     auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
975     if (!castOp)
976       return failure();
977 
978     if (!canFoldIntoConsumerOp(castOp))
979       return failure();
980 
981     /// Deduce the type of the result to use for the canonicalized operation.
982     RankedTensorType resultType = getCanonicalSliceResultType(
983         sliceOp.getType().getRank(), sliceOp.getSourceType(),
984         sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
985         sliceOp.getMixedStrides());
986     Value newSlice = rewriter.create<ExtractSliceOp>(
987         sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
988         sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
989         sliceOp.static_sizes(), sliceOp.static_strides());
990     rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
991                                                 newSlice);
992     return success();
993   }
994 };
995 } // namespace
996 
997 /// Return the canonical type of the result of an extract_slice op.
998 struct SliceReturnTypeCanonicalizer {
999   RankedTensorType operator()(ExtractSliceOp op,
1000                               ArrayRef<OpFoldResult> mixedOffsets,
1001                               ArrayRef<OpFoldResult> mixedSizes,
1002                               ArrayRef<OpFoldResult> mixedStrides) {
1003     return getCanonicalSliceResultType(op.getType().getRank(),
1004                                        op.getSourceType(), mixedOffsets,
1005                                        mixedSizes, mixedStrides);
1006   }
1007 };
1008 
1009 /// A canonicalizer wrapper to replace ExtractSliceOps.
1010 struct SliceCanonicalizer {
1011   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
1012                   ExtractSliceOp newOp) {
1013     Value replacement = newOp.getResult();
1014     if (replacement.getType() != op.getType())
1015       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
1016                                                     replacement);
1017     rewriter.replaceOp(op, replacement);
1018   }
1019 };
1020 
1021 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1022                                                  MLIRContext *context) {
1023   results.add<
1024       OpWithOffsetSizesAndStridesConstantArgumentFolder<
1025           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
1026       ExtractSliceOpCastFolder>(context);
1027 }
1028 
/// Return success if `op` is an identity slice of `shapedType`: all offsets
/// are statically 0, all strides are statically 1, and the sizes statically
/// match the leading extents of `shapedType`.
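///
/// For illustration (shapes arbitrary), the following extract_slice is such an
/// identity and can be folded to its source:
///
/// ```mlir
///   %0 = tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
///       : tensor<4x8xf32> to tensor<4x8xf32>
/// ```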
1030 static LogicalResult
1031 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
1032                                            ShapedType shapedType) {
1033   OpBuilder b(op.getContext());
1034   for (OpFoldResult ofr : op.getMixedOffsets())
1035     if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
1036       return failure();
1037   // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
1038   // is appropriate.
1039   auto shape = shapedType.getShape();
1040   for (auto it : llvm::zip(op.getMixedSizes(), shape))
1041     if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
1042       return failure();
1043   for (OpFoldResult ofr : op.getMixedStrides())
1044     if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
1045       return failure();
1046   return success();
1047 }
1048 
1049 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice,
1050 /// we can return the InsertSliceOp's source directly.
1051 // TODO: This only checks the immediate producer; extend to go up the
1052 // insert/extract chain if the slices are disjoint.
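///
/// A sketch of the fold (offsets, sizes, and strides must match exactly):
///
/// ```mlir
///   %0 = tensor.insert_slice %slice into %dest[0, 1] [16, 32] [1, 1]
///       : tensor<16x32xf32> into tensor<64x64xf32>
///   %1 = tensor.extract_slice %0[0, 1] [16, 32] [1, 1]
///       : tensor<64x64xf32> to tensor<16x32xf32>
/// ```
///
/// Here %1 can be replaced by %slice directly.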
1053 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
1054   auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
1055 
1056   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1057   if (insertOp && insertOp.source().getType() == extractOp.getType() &&
1058       insertOp.isSameAs(extractOp, isSame))
1059     return insertOp.source();
1060 
1061   return {};
1062 }
1063 
1064 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
1065   if (getSourceType() == getType() &&
1066       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1067     return this->source();
1068   if (Value slice = foldExtractAfterInsertSlice(*this))
1069     return slice;
1070   return OpFoldResult();
1071 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // InsertSliceOp
1075 //===----------------------------------------------------------------------===//
1076 
// Build an InsertSliceOp with mixed static and dynamic entries.
1078 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1079                           Value dest, ArrayRef<OpFoldResult> offsets,
1080                           ArrayRef<OpFoldResult> sizes,
1081                           ArrayRef<OpFoldResult> strides,
1082                           ArrayRef<NamedAttribute> attrs) {
1083   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1084   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1085   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1086                              ShapedType::kDynamicStrideOrOffset);
1087   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1088                              ShapedType::kDynamicSize);
1089   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1090                              ShapedType::kDynamicStrideOrOffset);
1091   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1092         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1093         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1094   result.addAttributes(attrs);
1095 }
1096 
// Build an InsertSliceOp with dynamic entries.
1098 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1099                           Value dest, ValueRange offsets, ValueRange sizes,
1100                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1101   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1102       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1103   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1104       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1105   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1106       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues,
        attrs);
1108 }
1109 
/// If we have two consecutive InsertSliceOps writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's
/// destination.
1112 ///
1113 /// Example:
1114 ///
1115 /// ```mlir
1116 ///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
1117 ///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
1118 /// ```
1119 ///
1120 /// folds into:
1121 ///
1122 /// ```mlir
1123 ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
1124 /// ```
1125 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
1126   auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
1127 
1128   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1129   if (!prevInsertOp ||
1130       prevInsertOp.source().getType() != insertOp.source().getType() ||
1131       !prevInsertOp.isSameAs(insertOp, isSame))
1132     return failure();
1133 
1134   insertOp.destMutable().assign(prevInsertOp.dest());
1135   return success();
1136 }
1137 
1138 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
1139   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
1140       getSourceType() == getType() &&
1141       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1142     return this->source();
1143   if (succeeded(foldInsertAfterInsertSlice(*this)))
1144     return getResult();
1145   return OpFoldResult();
1146 }
1147 
1148 LogicalResult InsertSliceOp::reifyResultShapes(
1149     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1150   reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1151   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1152     reifiedReturnShapes[0][dim] =
1153         builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
1154   }
1155   return success();
1156 }
1157 
1158 namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
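///
/// For example (constants illustrative), given
///
/// ```mlir
///   %c16 = arith.constant 16 : index
///   %0 = tensor.insert_slice %src into %dst[0, 0] [%c16, 8] [1, 1]
///       : tensor<?x8xf32> into tensor<64x64xf32>
/// ```
///
/// the dynamic size %c16 is folded into the op's static sizes as 16.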
1160 class InsertSliceOpConstantArgumentFolder final
1161     : public OpRewritePattern<InsertSliceOp> {
1162 public:
1163   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1164 
1165   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1166                                 PatternRewriter &rewriter) const override {
1167     // No constant operand, just return.
1168     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1169           return matchPattern(operand, matchConstantIndex());
1170         }))
1171       return failure();
1172 
1173     // At least one of offsets/sizes/strides is a new constant.
1174     // Form the new list of operands and constant attributes from the
1175     // existing.
1176     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1177     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1178     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1179     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1180     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1181     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1182 
1183     // Create the new op in canonical form.
1184     rewriter.replaceOpWithNewOp<InsertSliceOp>(
1185         insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
1186         mixedOffsets, mixedSizes, mixedStrides);
1187     return success();
1188   }
1189 };
1190 
/// Fold tensor.cast ops into insert_slice operations. If the source or
/// destination tensor is a tensor.cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
1194 ///
1195 /// ```mlir
1196 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1197 ///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1198 /// ```
1199 ///
1200 /// folds into:
1201 ///
1202 /// ```mlir
1203 ///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1204 /// ```
1205 ///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast back to the original type so that the type
/// of the result does not change.
1209 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
1210   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1211 
1212   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1213                                 PatternRewriter &rewriter) const override {
1214     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1215           return matchPattern(operand, matchConstantIndex());
1216         }))
1217       return failure();
1218 
1219     auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1220       auto castOp = v.getDefiningOp<tensor::CastOp>();
1221       if (!castOp || !canFoldIntoConsumerOp(castOp))
1222         return llvm::None;
1223       return castOp.source();
1224     };
1225     Optional<Value> sourceCastSource =
1226         getSourceOfCastOp(insertSliceOp.source());
1227     Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
1228     if (!sourceCastSource && !destCastSource)
1229       return failure();
1230 
1231     Value replacement = rewriter.create<InsertSliceOp>(
1232         insertSliceOp.getLoc(),
1233         (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
1234         (destCastSource ? *destCastSource : insertSliceOp.dest()),
1235         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1236         insertSliceOp.getMixedStrides());
1237 
1238     if (replacement.getType() != insertSliceOp.getType()) {
1239       replacement = rewriter.create<tensor::CastOp>(
1240           insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
1241     }
1242     rewriter.replaceOp(insertSliceOp, replacement);
1243     return success();
1244   }
1245 };
1246 
/// If additional static type information can be deduced from an insert_slice's
1248 /// size operands, insert an explicit cast of the op's source operand. This
1249 /// enables other canonicalization patterns that are matching for tensor_cast
1250 /// ops such as `ForOpTensorCastFolder` in SCF.
1251 ///
1252 /// Example:
1253 ///
1254 /// ```mlir
1255 ///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1256 ///       : tensor<?x?xf32> into ...
1257 /// ```
1258 ///
1259 /// folds into:
1260 ///
1261 /// ```mlir
1262 ///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1263 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1264 ///       : tensor<64x64xf32> into ...
1265 /// ```
1266 struct InsertSliceOpSourceCastInserter final
1267     : public OpRewritePattern<InsertSliceOp> {
1268   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1269 
1270   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1271                                 PatternRewriter &rewriter) const override {
1272     RankedTensorType srcType = insertSliceOp.getSourceType();
1273     if (srcType.getRank() != insertSliceOp.getType().getRank())
1274       return failure();
1275     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1276                                      srcType.getShape().end());
1277     for (int64_t i = 0; i < srcType.getRank(); ++i) {
1278       if (Optional<int64_t> constInt =
1279               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1280         newSrcShape[i] = *constInt;
1281     }
1282 
1283     RankedTensorType newSrcType =
1284         RankedTensorType::get(newSrcShape, srcType.getElementType());
1285     if (srcType == newSrcType ||
1286         !preservesStaticInformation(srcType, newSrcType) ||
1287         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
1288       return failure();
1289 
1290     // newSrcType is:
1291     //   1) Different from srcType.
1292     //   2) "More static" than srcType.
1293     //   3) Cast-compatible with srcType.
1294     // Insert the cast.
1295     Value cast = rewriter.create<tensor::CastOp>(
1296         insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
1297     rewriter.replaceOpWithNewOp<InsertSliceOp>(
1298         insertSliceOp, cast, insertSliceOp.dest(),
1299         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1300         insertSliceOp.getMixedStrides());
1301     return success();
1302   }
1303 };
1304 } // namespace
1305 
1306 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1307                                                 MLIRContext *context) {
1308   results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
1309               InsertSliceOpSourceCastInserter>(context);
1310 }
1311 
1312 //===----------------------------------------------------------------------===//
1313 // TableGen'd op method definitions
1314 //===----------------------------------------------------------------------===//
1315 
1316 #define GET_OP_CLASSES
1317 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
1318