1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/Dialect/Utils/StaticValueUtils.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::tensor;
21 
22 /// Materialize a single constant operation from a given attribute value with
23 /// the desired resultant type.
24 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
25                                               Attribute value, Type type,
26                                               Location loc) {
27   return builder.create<mlir::ConstantOp>(loc, type, value);
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // CastOp
32 //===----------------------------------------------------------------------===//
33 
34 /// Returns true if `target` is a ranked tensor type that preserves static
35 /// information available in the `source` ranked tensor type.
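///
/// Examples (illustrative):
///   - source = tensor<?x16xf32>, target = tensor<8x16xf32>: returns true,
///     every static extent of the source is kept in the target.
///   - source = tensor<8x16xf32>, target = tensor<?x16xf32>: returns false,
///     the static extent 8 is dropped.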
36 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
37   auto sourceType = source.dyn_cast<RankedTensorType>();
38   auto targetType = target.dyn_cast<RankedTensorType>();
39 
40   // Requires RankedTensorType.
41   if (!sourceType || !targetType)
42     return false;
43 
44   // Requires same elemental type.
45   if (sourceType.getElementType() != targetType.getElementType())
46     return false;
47 
48   // Requires same rank.
49   if (sourceType.getRank() != targetType.getRank())
50     return false;
51 
  // Static information is lost if the source has a static extent along a
  // dimension where the target is dynamic.
53   for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
54     if (!ShapedType::isDynamic(std::get<0>(t)) &&
55         ShapedType::isDynamic(std::get<1>(t)))
56       return false;
57   }
58 
59   return true;
60 }
61 
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted alongside `slice` ops and
/// are canonicalized away to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has at least as much static information as the result
///    type.
72 ///
73 /// Example:
74 /// ```mlir
75 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
76 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
77 /// ```
78 ///
79 /// folds into:
80 ///
81 /// ```mlir
82 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
83 /// ```
84 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
85   if (!castOp)
86     return false;
87 
  // Can fold if the source of the cast has at least as much static
  // information as its result.
90   return preservesStaticInformation(castOp.getType(),
91                                     castOp.source().getType());
92 }
93 
94 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
95 /// that can be folded.
96 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
97   bool folded = false;
98   for (OpOperand &operand : op->getOpOperands()) {
99     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
100     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
101       operand.set(castOp.getOperand());
102       folded = true;
103     }
104   }
105   return success(folded);
106 }
107 
108 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
109   if (inputs.size() != 1 || outputs.size() != 1)
110     return false;
111   Type a = inputs.front(), b = outputs.front();
112   auto aT = a.dyn_cast<TensorType>();
113   auto bT = b.dyn_cast<TensorType>();
114   if (!aT || !bT)
115     return false;
116 
117   if (aT.getElementType() != bT.getElementType())
118     return false;
119 
120   return succeeded(verifyCompatibleShape(aT, bT));
121 }
122 
123 /// Compute a TensorType that has the joined shape knowledge of the two
124 /// given TensorTypes. The element types need to match.
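///
/// For example (illustrative), joining tensor<?x8xf32> and tensor<4x?xf32>
/// yields tensor<4x8xf32>. If the shapes conflict (e.g. tensor<4xf32> and
/// tensor<8xf32>) or the ranks differ, a null type is returned.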
125 static TensorType joinShapes(TensorType one, TensorType two) {
126   assert(one.getElementType() == two.getElementType());
127 
128   if (!one.hasRank())
129     return two;
130   if (!two.hasRank())
131     return one;
132 
133   int64_t rank = one.getRank();
134   if (rank != two.getRank())
135     return {};
136 
137   SmallVector<int64_t, 4> join;
138   join.reserve(rank);
139   for (int64_t i = 0; i < rank; ++i) {
140     if (one.isDynamicDim(i)) {
141       join.push_back(two.getDimSize(i));
142       continue;
143     }
144     if (two.isDynamicDim(i)) {
145       join.push_back(one.getDimSize(i));
146       continue;
147     }
148     if (one.getDimSize(i) != two.getDimSize(i))
149       return {};
150     join.push_back(one.getDimSize(i));
151   }
152   return RankedTensorType::get(join, one.getElementType());
153 }
154 
155 namespace {
156 
157 /// Replaces chains of two tensor.cast operations by a single tensor.cast
158 /// operation if doing so does not remove runtime constraints.
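///
/// Example (an illustrative sketch):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
///   %2 = tensor.cast %1 : tensor<4x?xf32> to tensor<4x8xf32>
/// ```
///
/// can be replaced by:
///
/// ```mlir
///   %2 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```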
159 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
160   using OpRewritePattern<CastOp>::OpRewritePattern;
161 
162   LogicalResult matchAndRewrite(CastOp tensorCast,
163                                 PatternRewriter &rewriter) const final {
164     auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
165 
166     if (!tensorCastOperand)
167       return failure();
168 
169     auto sourceType =
170         tensorCastOperand.getOperand().getType().cast<TensorType>();
171     auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
172     auto resultType = tensorCast.getType().cast<TensorType>();
173 
174     // We can remove the intermediate cast if joining all three produces the
175     // same result as just joining the source and result shapes.
176     auto firstJoin =
177         joinShapes(joinShapes(sourceType, intermediateType), resultType);
178 
179     // The join might not exist if the cast sequence would fail at runtime.
180     if (!firstJoin)
181       return failure();
182 
    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
186     auto newJoin = joinShapes(sourceType, resultType);
187     if (firstJoin != newJoin)
188       return failure();
189 
190     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
191                                         tensorCastOperand.getOperand());
192     return success();
193   }
194 };
195 
196 } // namespace
197 
198 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
199                                          MLIRContext *context) {
200   results.add<ChainedTensorCast>(context);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // DimOp
205 //===----------------------------------------------------------------------===//
206 
207 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
208                   int64_t index) {
209   auto loc = result.location;
210   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
211   build(builder, result, source, indexValue);
212 }
213 
214 Optional<int64_t> DimOp::getConstantIndex() {
215   if (auto constantOp = index().getDefiningOp<ConstantOp>())
216     return constantOp.getValue().cast<IntegerAttr>().getInt();
217   return {};
218 }
219 
220 static LogicalResult verify(DimOp op) {
221   // Assume unknown index to be in range.
222   Optional<int64_t> index = op.getConstantIndex();
223   if (!index.hasValue())
224     return success();
225 
226   // Check that constant index is not knowingly out of range.
227   auto type = op.source().getType();
228   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
229     if (index.getValue() >= tensorType.getRank())
230       return op.emitOpError("index is out of range");
231   } else if (type.isa<UnrankedTensorType>()) {
232     // Assume index to be in range.
233   } else {
234     llvm_unreachable("expected operand with tensor type");
235   }
236   return success();
237 }
238 
239 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
240   // All forms of folding require a known index.
241   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
242   if (!index)
243     return {};
244 
245   // Folding for unranked types (UnrankedTensorType) is not supported.
246   auto tensorType = source().getType().dyn_cast<RankedTensorType>();
247   if (!tensorType)
248     return {};
249 
250   // Fold if the shape extent along the given index is known.
251   if (!tensorType.isDynamicDim(index.getInt())) {
252     Builder builder(getContext());
253     return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
254   }
255 
256   Operation *definingOp = source().getDefiningOp();
257 
258   // Fold dim to the operand of tensor.generate.
259   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
260     auto resultType =
261         fromElements.getResult().getType().cast<RankedTensorType>();
262     // The case where the type encodes the size of the dimension is handled
263     // above.
264     assert(resultType.getShape()[index.getInt()] ==
265            RankedTensorType::kDynamicSize);
266 
267     // Find the operand of the fromElements that corresponds to this index.
268     auto dynExtents = fromElements.dynamicExtents().begin();
269     for (auto dim : resultType.getShape().take_front(index.getInt()))
270       if (dim == RankedTensorType::kDynamicSize)
271         dynExtents++;
272 
273     return Value{*dynExtents};
274   }
275 
276   // The size at the given index is now known to be a dynamic size.
277   unsigned unsignedIndex = index.getValue().getZExtValue();
278 
279   if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
280     assert(sliceOp.isDynamicSize(unsignedIndex) &&
281            "Expected dynamic slice size");
282     return sliceOp.getDynamicSize(unsignedIndex);
283   }
284 
285   // dim(cast) -> dim
286   if (succeeded(foldTensorCast(*this)))
287     return getResult();
288 
289   return {};
290 }
291 
292 namespace {
293 /// Fold dim of a cast into the dim of the source of the tensor cast.
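///
/// Example (an illustrative sketch):
///
/// ```mlir
///   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
///   %1 = tensor.dim %arg0, %c1 : tensor<4x?xf32>
/// ```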
294 struct DimOfCastOp : public OpRewritePattern<DimOp> {
295   using OpRewritePattern<DimOp>::OpRewritePattern;
296 
297   LogicalResult matchAndRewrite(DimOp dimOp,
298                                 PatternRewriter &rewriter) const override {
299     auto castOp = dimOp.source().getDefiningOp<CastOp>();
300     if (!castOp)
301       return failure();
302     Value newSource = castOp.getOperand();
303     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
304     return success();
305   }
306 };
307 } // end anonymous namespace.
308 
309 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
310                                         MLIRContext *context) {
311   results.add<DimOfCastOp>(context);
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // ExtractOp
316 //===----------------------------------------------------------------------===//
317 
318 static LogicalResult verify(ExtractOp op) {
319   // Verify the # indices match if we have a ranked type.
320   if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
321     if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
322       return op.emitOpError("incorrect number of indices for extract_element");
323 
324   return success();
325 }
326 
327 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
328   // The tensor operand must be a known constant.
329   Attribute tensor = operands.front();
330   if (!tensor)
331     return {};
332   // If this is a splat elements attribute, simply return the value. All of the
333   // elements of a splat attribute are the same.
334   if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
335     return splatTensor.getSplatValue();
336 
337   // Otherwise, collect the constant indices into the tensor.
338   SmallVector<uint64_t, 8> indices;
  for (Attribute indexAttr : llvm::drop_begin(operands, 1)) {
    if (!indexAttr || !indexAttr.isa<IntegerAttr>())
      return {};
    indices.push_back(indexAttr.cast<IntegerAttr>().getInt());
  }
344 
345   // If this is an elements attribute, query the value at the given indices.
346   auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
347   if (elementsAttr && elementsAttr.isValidIndex(indices))
348     return elementsAttr.getValue(indices);
349   return {};
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // FromElementsOp
354 //===----------------------------------------------------------------------===//
355 
356 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
357                            Type elementType, ValueRange elements) {
358   Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
359                                         elementType);
360   result.addOperands(elements);
361   result.addTypes(resultTy);
362 }
363 
364 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
365                            ValueRange elements) {
366   assert(!elements.empty() && "expected at least one element");
367   build(builder, result, elements.front().getType(), elements);
368 }
369 
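/// Folds to a constant DenseElementsAttr when all element operands are
/// constant. For example (illustrative), a tensor.from_elements of two i32
/// constants 0 and 1 folds to dense<[0, 1]> : tensor<2xi32>.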
370 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
371   if (!llvm::is_contained(operands, nullptr))
372     return DenseElementsAttr::get(getType(), operands);
373   return {};
374 }
375 
376 namespace {
377 
378 // Canonicalizes the pattern of the form
379 //
380 // %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
381 // %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
382 //
383 // to just %element.
384 struct ExtractElementFromTensorFromElements
385     : public OpRewritePattern<tensor::ExtractOp> {
386   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
387 
388   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
389                                 PatternRewriter &rewriter) const final {
390     if (extract.indices().size() != 1)
391       return failure();
392 
393     auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
394     if (tensorFromElements == nullptr)
395       return failure();
396 
397     APInt index;
398     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
399       return failure();
400     // Prevent out of bounds accesses. This can happen in invalid code that will
401     // never execute.
402     if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
403         index.getSExtValue() < 0)
404       return failure();
405     rewriter.replaceOp(extract,
406                        tensorFromElements.getOperand(index.getZExtValue()));
407     return success();
408   }
409 };
410 
411 } // namespace
412 
413 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
414                                                  MLIRContext *context) {
415   results.add<ExtractElementFromTensorFromElements>(context);
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // InsertOp
420 //===----------------------------------------------------------------------===//
421 
422 static LogicalResult verify(InsertOp op) {
423   // Verify the # indices match if we have a ranked type.
424   if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
425     if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
426       return op.emitOpError("incorrect number of indices");
427   return success();
428 }
429 
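/// Folds an insert into a splat constant destination when the inserted scalar
/// equals the splat value, returning the destination attribute unchanged.
/// For example (illustrative), inserting a constant 0.0 into a dense<0.0>
/// tensor constant folds to that same dense<0.0> constant.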
430 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
431   Attribute scalar = operands[0];
432   Attribute dest = operands[1];
433   if (scalar && dest)
434     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
435       if (scalar == splatDest.getSplatValue())
436         return dest;
437   return {};
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // GenerateOp
442 //===----------------------------------------------------------------------===//
443 
444 static LogicalResult verify(GenerateOp op) {
445   // Ensure that the tensor type has as many dynamic dimensions as are specified
446   // by the operands.
447   RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
448   if (op.getNumOperands() != resultTy.getNumDynamicDims())
449     return op.emitError("must have as many index operands as dynamic extents "
450                         "in the result type");
451 
452   // Ensure that region arguments span the index space.
453   if (!llvm::all_of(op.body().getArgumentTypes(),
454                     [](Type ty) { return ty.isIndex(); }))
455     return op.emitError("all body arguments must be index");
456   if (op.body().getNumArguments() != resultTy.getRank())
457     return op.emitError("must have one body argument per input dimension");
458 
459   // Ensure that the region yields an element of the right type.
460   auto yieldOp =
461       llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
462   if (yieldOp.value().getType() != resultTy.getElementType())
463     return op.emitOpError(
464         "body must be terminated with a `yield` operation of the tensor "
465         "element type");
466 
467   return success();
468 }
469 
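/// Illustrative usage sketch for the builder below (`builder`, `loc`, and
/// `size` are hypothetical values, not defined in this file):
///
/// ```c++
///   // Builds `tensor.generate %size { ... } : tensor<?xindex>` whose
///   // elements equal their index.
///   auto type = RankedTensorType::get({ShapedType::kDynamicSize},
///                                     builder.getIndexType());
///   builder.create<GenerateOp>(
///       loc, type, ValueRange{size},
///       [](OpBuilder &b, Location nestedLoc, ValueRange indices) {
///         b.create<tensor::YieldOp>(nestedLoc, indices[0]);
///       });
/// ```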
470 void GenerateOp::build(
471     OpBuilder &b, OperationState &result, Type resultTy,
472     ValueRange dynamicExtents,
473     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
474   build(b, result, resultTy, dynamicExtents);
475 
476   // Build and populate body.
477   OpBuilder::InsertionGuard guard(b);
478   Region *bodyRegion = result.regions.front().get();
479   auto rank = resultTy.cast<RankedTensorType>().getRank();
480   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
481   Block *bodyBlock =
482       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
483   bodyBuilder(b, result.location, bodyBlock->getArguments());
484 }
485 
486 namespace {
487 
/// Canonicalizes tensor.generate operations with constant extent operands
/// into the equivalent operation with those extents reflected in the result
/// type. A tensor.cast back to the original result type is inserted so that
/// the resulting IR remains well-typed.
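///
/// Example (an illustrative sketch):
///
/// ```mlir
///   %c5 = constant 5 : index
///   %0 = tensor.generate %c5, %n {
///     ...
///   } : tensor<?x?xf32>
/// ```
///
/// is rewritten into a tensor.generate of type tensor<5x?xf32> that only
/// takes %n, followed by a tensor.cast back to tensor<?x?xf32>.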
492 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
493   using OpRewritePattern<GenerateOp>::OpRewritePattern;
494 
495   LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
496                                 PatternRewriter &rewriter) const final {
497     auto resultType =
498         tensorFromElements.getResult().getType().cast<RankedTensorType>();
499 
500     if (resultType.hasStaticShape())
501       return failure();
502 
503     SmallVector<Value, 4> newOperands;
504     SmallVector<int64_t, 4> newShape;
505     auto operandsIt = tensorFromElements.dynamicExtents().begin();
506 
507     for (int64_t dim : resultType.getShape()) {
508       if (dim != RankedTensorType::kDynamicSize) {
509         newShape.push_back(dim);
510         continue;
511       }
512       APInt index;
513       if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
514         newShape.push_back(RankedTensorType::kDynamicSize);
515         newOperands.push_back(*operandsIt++);
516         continue;
517       }
518       newShape.push_back(index.getSExtValue());
519       operandsIt++;
520     }
521 
522     if (newOperands.size() == tensorFromElements.dynamicExtents().size())
523       return failure();
524 
525     auto loc = tensorFromElements.getLoc();
526     auto newOp = rewriter.create<GenerateOp>(
527         loc, RankedTensorType::get(newShape, resultType.getElementType()),
528         newOperands);
529     rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
530                                 newOp.body().begin());
531     rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
532                                                 newOp);
533     return success();
534   }
535 };
536 
537 /// Canonicalizes the pattern of the form
538 ///
539 /// %tensor = tensor.generate %x {
540 ///   ^bb0(%arg0: index):  // no predecessors
541 ///   <computation>
542 ///   yield %1 : index
543 /// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
545 ///
546 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
547 /// tensor.generate operation has no side-effects.
548 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
549   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
550 
551   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
552                                 PatternRewriter &rewriter) const final {
553     auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
554     if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
555       return failure();
556 
557     BlockAndValueMapping mapping;
558     Block *body = tensorFromElements.getBody();
559     mapping.map(body->getArguments(), extract.indices());
560     for (auto &op : body->without_terminator())
561       rewriter.clone(op, mapping);
562 
563     auto yield = cast<YieldOp>(body->getTerminator());
564 
565     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
566     return success();
567   }
568 };
569 
570 /// Canonicalizes the pattern of the form
571 ///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
573 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
574 ///
575 /// to
576 ///
577 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
578 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
579   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
580 
581   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
582                                 PatternRewriter &rewriter) const final {
583     auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
584     if (!tensorCast)
585       return failure();
586 
587     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
588                                                    extract.indices());
589     return success();
590   }
591 };
592 
593 } // namespace
594 
595 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
596                                              MLIRContext *context) {
597   // TODO: Move extract patterns to tensor::ExtractOp.
598   results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
599               StaticTensorGenerate>(context);
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // ReshapeOp
604 //===----------------------------------------------------------------------===//
605 
606 static int64_t GetNumElements(ShapedType type) {
607   int64_t numElements = 1;
608   for (auto dim : type.getShape())
609     numElements *= dim;
610   return numElements;
611 }
612 
613 static LogicalResult verify(ReshapeOp op) {
614   TensorType operandType = op.source().getType().cast<TensorType>();
615   TensorType resultType = op.result().getType().cast<TensorType>();
616 
617   if (operandType.getElementType() != resultType.getElementType())
618     return op.emitOpError("element types of source and destination tensor "
619                           "types should be the same");
620 
621   int64_t shapeSize =
622       op.shape().getType().cast<RankedTensorType>().getDimSize(0);
623   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
624   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
625 
626   if (resultRankedType) {
627     if (operandRankedType && resultRankedType.hasStaticShape() &&
628         operandRankedType.hasStaticShape()) {
629       if (GetNumElements(operandRankedType) != GetNumElements(resultRankedType))
630         return op.emitOpError("source and destination tensor should have the "
631                               "same number of elements");
632     }
633     if (shapeSize == TensorType::kDynamicSize)
634       return op.emitOpError("cannot use shape operand with dynamic length to "
635                             "reshape to statically-ranked tensor type");
636     if (shapeSize != resultRankedType.getRank())
637       return op.emitOpError(
638           "length of shape operand differs from the result's tensor rank");
639   }
640   return success();
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // ExtractSliceOp
645 //===----------------------------------------------------------------------===//
646 
647 /// An extract_slice op result type can be fully inferred from the source type
648 /// and the static representation of offsets, sizes and strides. Special
649 /// sentinels encode the dynamic case.
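///
/// For example (illustrative), with a source type of tensor<8x16x4xf32> and
/// leading static sizes [2, ShapedType::kDynamicSize], the inferred result
/// type is tensor<2x?x4xf32>: the missing trailing size is completed from the
/// source shape.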
650 Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
651                                      ArrayRef<int64_t> leadingStaticOffsets,
652                                      ArrayRef<int64_t> leadingStaticSizes,
653                                      ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offset/sizes/
  // strides in which case we complete with offset=0, sizes from the source
  // tensor type and strides=1.
657   unsigned rank = sourceRankedTensorType.getRank();
658   assert(leadingStaticSizes.size() <= rank &&
659          "unexpected leadingStaticSizes overflow");
660   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
661   unsigned numTrailingSizes = rank - staticSizes.size();
662   llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
663                                       numTrailingSizes));
664   return RankedTensorType::get(staticSizes,
665                                sourceRankedTensorType.getElementType());
666 }
667 
668 Type ExtractSliceOp::inferResultType(
669     RankedTensorType sourceRankedTensorType,
670     ArrayRef<OpFoldResult> leadingStaticOffsets,
671     ArrayRef<OpFoldResult> leadingStaticSizes,
672     ArrayRef<OpFoldResult> leadingStaticStrides) {
673   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
674   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
675   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
676                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
677   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
678                              ShapedType::kDynamicSize);
679   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
680                              staticStrides, ShapedType::kDynamicStrideOrOffset);
681   return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
682                                          staticSizes, staticStrides);
683 }
684 
685 /// An extract_slice op result type can be fully inferred from the source type
686 /// and the static representation of offsets, sizes and strides. Special
687 /// sentinels encode the dynamic case.
688 Type ExtractSliceOp::inferRankReducedResultType(
689     unsigned resultRank, RankedTensorType sourceRankedTensorType,
690     ArrayRef<int64_t> leadingStaticOffsets,
691     ArrayRef<int64_t> leadingStaticSizes,
692     ArrayRef<int64_t> leadingStaticStrides) {
693   auto inferredType =
694       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
695                       leadingStaticSizes, leadingStaticStrides)
696           .cast<RankedTensorType>();
697   int rankDiff = inferredType.getRank() - resultRank;
698   if (rankDiff > 0) {
699     auto shape = inferredType.getShape();
700     llvm::SmallDenseSet<unsigned> dimsToProject;
701     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
702     SmallVector<int64_t> projectedShape;
703     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
704       if (!dimsToProject.contains(pos))
705         projectedShape.push_back(shape[pos]);
706     inferredType =
707         RankedTensorType::get(projectedShape, inferredType.getElementType());
708   }
709   return inferredType;
710 }
711 
712 Type ExtractSliceOp::inferRankReducedResultType(
713     unsigned resultRank, RankedTensorType sourceRankedTensorType,
714     ArrayRef<OpFoldResult> leadingStaticOffsets,
715     ArrayRef<OpFoldResult> leadingStaticSizes,
716     ArrayRef<OpFoldResult> leadingStaticStrides) {
717   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
718   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
719   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
720                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
721   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
722                              ShapedType::kDynamicSize);
723   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
724                              staticStrides, ShapedType::kDynamicStrideOrOffset);
725   return ExtractSliceOp::inferRankReducedResultType(
726       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
727       staticStrides);
728 }
729 
730 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
731 /// result type. If the type passed is nullptr, it is inferred.
732 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
733                            RankedTensorType resultType, Value source,
734                            ArrayRef<OpFoldResult> offsets,
735                            ArrayRef<OpFoldResult> sizes,
736                            ArrayRef<OpFoldResult> strides,
737                            ArrayRef<NamedAttribute> attrs) {
738   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
739   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
740   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
741                              ShapedType::kDynamicStrideOrOffset);
742   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
743                              ShapedType::kDynamicSize);
744   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
745                              ShapedType::kDynamicStrideOrOffset);
746   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
747   // Structuring implementation this way avoids duplication between builders.
748   if (!resultType) {
749     resultType =
750         ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
751                                         staticSizes, staticStrides)
752             .cast<RankedTensorType>();
753   }
754   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
755         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
756         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
757   result.addAttributes(attrs);
758 }
759 
760 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
761 /// result type.
762 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
763                            ArrayRef<OpFoldResult> offsets,
764                            ArrayRef<OpFoldResult> sizes,
765                            ArrayRef<OpFoldResult> strides,
766                            ArrayRef<NamedAttribute> attrs) {
767   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
768 }
769 
770 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
771 /// type passed is nullptr, it is inferred.
772 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
773                            RankedTensorType resultType, Value source,
774                            ValueRange offsets, ValueRange sizes,
775                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
776   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
777       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
778   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
779       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
780   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
781       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
782   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
783 }
784 
785 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
786 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
787                            ValueRange offsets, ValueRange sizes,
788                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
789   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
790 }
791 
792 enum SliceVerificationResult {
793   Success,
794   RankTooLarge,
795   SizeMismatch,
796   ElemTypeMismatch,
797 };
798 
/// Checks if the `originalType` type can be rank-reduced to the
/// `candidateReducedType` type. This is a slight variant of the "is
/// subsequence" algorithm, where every non-matching dimension must be 1.
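/// For example (illustrative), tensor<1x4x1x8xf32> can be rank-reduced to
/// tensor<4x8xf32> or tensor<1x4x8xf32>, but not to tensor<4x4xf32>.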
802 static SliceVerificationResult
803 isRankReducedType(Type originalType, Type candidateReducedType,
804                   std::string *errMsg = nullptr) {
805   if (originalType == candidateReducedType)
806     return SliceVerificationResult::Success;
807   if (!originalType.isa<RankedTensorType>())
808     return SliceVerificationResult::Success;
809   if (originalType.isa<RankedTensorType>() &&
810       !candidateReducedType.isa<RankedTensorType>())
811     return SliceVerificationResult::Success;
812 
813   ShapedType originalShapedType = originalType.cast<ShapedType>();
814   ShapedType candidateReducedShapedType =
815       candidateReducedType.cast<ShapedType>();
816 
817   // Rank and size logic is valid for all ShapedTypes.
818   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
819   ArrayRef<int64_t> candidateReducedShape =
820       candidateReducedShapedType.getShape();
821   unsigned originalRank = originalShape.size(),
822            candidateReducedRank = candidateReducedShape.size();
823   if (candidateReducedRank > originalRank)
824     return SliceVerificationResult::RankTooLarge;
825 
826   auto optionalUnusedDimsMask =
827       computeRankReductionMask(originalShape, candidateReducedShape);
828 
  // Sizes cannot be matched if no rank-reduction mask could be computed.
830   if (!optionalUnusedDimsMask.hasValue())
831     return SliceVerificationResult::SizeMismatch;
832 
833   if (originalShapedType.getElementType() !=
834       candidateReducedShapedType.getElementType())
835     return SliceVerificationResult::ElemTypeMismatch;
836 
  return SliceVerificationResult::Success;
842 }
843 
844 template <typename OpTy>
845 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
846                                           OpTy op, Type expectedType,
847                                           StringRef errMsg = "") {
  auto shapedType = expectedType.cast<ShapedType>();
849   switch (result) {
850   case SliceVerificationResult::Success:
851     return success();
852   case SliceVerificationResult::RankTooLarge:
853     return op.emitError("expected result rank to be smaller or equal to ")
854            << "the source rank. " << errMsg;
855   case SliceVerificationResult::SizeMismatch:
856     return op.emitError("expected result type to be ")
857            << expectedType
858            << " or a rank-reduced version. (mismatch of result sizes) "
859            << errMsg;
860   case SliceVerificationResult::ElemTypeMismatch:
861     return op.emitError("expected result element type to be ")
           << shapedType.getElementType() << errMsg;
863   }
864   llvm_unreachable("unexpected extract_slice op verification result");
865 }
866 
867 /// Verifier for ExtractSliceOp.
868 static LogicalResult verify(ExtractSliceOp op) {
869   // Verify result type against inferred type.
870   auto expectedType = ExtractSliceOp::inferResultType(
871       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
872       extractFromI64ArrayAttr(op.static_sizes()),
873       extractFromI64ArrayAttr(op.static_strides()));
874   auto result = isRankReducedType(expectedType, op.getType());
875   return produceSliceErrorMsg(result, op, expectedType);
876 }
877 
878 /// Infer the canonical type of the result of an extract_slice op. Returns a
879 /// type with rank `resultRank` that is either the rank of the rank-reduced
880 /// type, or the non-rank-reduced type.
881 static RankedTensorType
882 getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
883                             ArrayRef<OpFoldResult> mixedOffsets,
884                             ArrayRef<OpFoldResult> mixedSizes,
885                             ArrayRef<OpFoldResult> mixedStrides) {
886   auto resultType =
887       ExtractSliceOp::inferRankReducedResultType(
888           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
889           .cast<RankedTensorType>();
890   if (resultType.getRank() != resultRank) {
891     resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
892                                                  mixedSizes, mixedStrides)
893                      .cast<RankedTensorType>();
894   }
895   return resultType;
896 }
897 
898 namespace {
899 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes a tensor.cast past its consuming extract_slice op
/// when `canFoldIntoConsumerOp` is true.
902 ///
903 /// Example:
904 /// ```
905 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
906 ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
907 ///   tensor<3x4xf32>
908 /// ```
909 /// is rewritten into:
910 /// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///   tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
913 /// ```
914 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
915 public:
916   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
917 
918   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
919                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant index, just bail and let the constant
    // argument folder (OpWithOffsetSizesAndStridesConstantArgumentFolder)
    // kick in first.
921     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
922           return matchPattern(operand, matchConstantIndex());
923         }))
924       return failure();
925 
926     auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
927     if (!castOp)
928       return failure();
929 
930     if (!canFoldIntoConsumerOp(castOp))
931       return failure();
932 
933     /// Deduce the type of the result to use for the canonicalized operation.
934     RankedTensorType resultType = getCanonicalSliceResultType(
935         sliceOp.getType().getRank(), sliceOp.getSourceType(),
936         sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
937         sliceOp.getMixedStrides());
938     Value newSlice = rewriter.create<ExtractSliceOp>(
939         sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
940         sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
941         sliceOp.static_sizes(), sliceOp.static_strides());
942     rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
943                                                 newSlice);
944     return success();
945   }
946 };
947 } // namespace
948 
949 /// Return the canonical type of the result of an extract_slice op.
950 struct SliceReturnTypeCanonicalizer {
951   RankedTensorType operator()(ExtractSliceOp op,
952                               ArrayRef<OpFoldResult> mixedOffsets,
953                               ArrayRef<OpFoldResult> mixedSizes,
954                               ArrayRef<OpFoldResult> mixedStrides) {
955     return getCanonicalSliceResultType(op.getType().getRank(),
956                                        op.getSourceType(), mixedOffsets,
957                                        mixedSizes, mixedStrides);
958   }
959 };
960 
961 /// A canonicalizer wrapper to replace ExtractSliceOps.
962 struct SliceCanonicalizer {
963   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
964                   ExtractSliceOp newOp) {
965     Value replacement = newOp.getResult();
966     if (replacement.getType() != op.getType())
967       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
968                                                     replacement);
969     rewriter.replaceOp(op, replacement);
970   }
971 };
972 
973 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
974                                                  MLIRContext *context) {
975   results.add<
976       OpWithOffsetSizesAndStridesConstantArgumentFolder<
977           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
978       ExtractSliceOpCastFolder>(context);
979 }
980 
/// Returns success when `op` is an identity slice of `shapedType`: all
/// offsets are constant 0, all sizes are constants equal to the corresponding
/// dimension of `shapedType`, and all strides are constant 1.
982 static LogicalResult
983 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
984                                            ShapedType shapedType) {
985   OpBuilder b(op.getContext());
986   for (OpFoldResult ofr : op.getMixedOffsets())
987     if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
988       return failure();
989   // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
990   // is appropriate.
991   auto shape = shapedType.getShape();
992   for (auto it : llvm::zip(op.getMixedSizes(), shape))
993     if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
994       return failure();
995   for (OpFoldResult ofr : op.getMixedStrides())
996     if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
997       return failure();
998   return success();
999 }
1000 
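/// Folds an identity extract_slice to its source. For example (illustrative),
/// tensor.extract_slice %t[0, 0] [4, 8] [1, 1] : tensor<4x8xf32> to
/// tensor<4x8xf32> folds to %t.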
1001 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
1002   if (getSourceType() == getType() &&
1003       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1004     return this->source();
1005   return OpFoldResult();
1006 }
1007 
1008 //===----------------------------------------------------------------------===//
1009 // InsertSliceOp
1010 //===----------------------------------------------------------------------===//
1011 
// Build an InsertSliceOp with mixed static and dynamic entries.
1013 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1014                           Value dest, ArrayRef<OpFoldResult> offsets,
1015                           ArrayRef<OpFoldResult> sizes,
1016                           ArrayRef<OpFoldResult> strides,
1017                           ArrayRef<NamedAttribute> attrs) {
1018   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1019   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1020   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1021                              ShapedType::kDynamicStrideOrOffset);
1022   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1023                              ShapedType::kDynamicSize);
1024   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1025                              ShapedType::kDynamicStrideOrOffset);
1026   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1027         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1028         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1029   result.addAttributes(attrs);
1030 }
1031 
// Build an InsertSliceOp with dynamic entries.
1033 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1034                           Value dest, ValueRange offsets, ValueRange sizes,
1035                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1036   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1037       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1038   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1039       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1040   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1041       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1042   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1043 }
1044 
1045 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
1046   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
1047       getSourceType() == getType() &&
1048       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1049     return this->source();
1050   return OpFoldResult();
1051 }
1052 
1053 LogicalResult InsertSliceOp::reifyResultShapes(
1054     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1055   reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1056   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1057     reifiedReturnShapes[0][dim] =
1058         builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
1059   }
1060   return success();
1061 }
1062 
1063 namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
1065 class InsertSliceOpConstantArgumentFolder final
1066     : public OpRewritePattern<InsertSliceOp> {
1067 public:
1068   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1069 
1070   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1071                                 PatternRewriter &rewriter) const override {
1072     // No constant operand, just return.
1073     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1074           return matchPattern(operand, matchConstantIndex());
1075         }))
1076       return failure();
1077 
1078     // At least one of offsets/sizes/strides is a new constant.
1079     // Form the new list of operands and constant attributes from the
1080     // existing.
1081     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1082     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1083     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1084     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1085     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1086     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1087 
1088     // Create the new op in canonical form.
1089     rewriter.replaceOpWithNewOp<InsertSliceOp>(
1090         insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
1091         mixedOffsets, mixedSizes, mixedStrides);
1092     return success();
1093   }
1094 };
1095 
/// Fold tensor.cast ops into insert_slice operations. If the source or the
/// destination tensor is produced by a tensor.cast that removes static type
/// information, the cast is folded into the insert_slice operation. E.g.:
1099 ///
1100 /// ```mlir
1101 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1102 ///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1103 /// ```
1104 ///
1105 /// folds into:
1106 ///
1107 /// ```mlir
1108 ///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1109 /// ```
1110 ///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast back to ensure that the type of the result
/// does not change.
1114 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
1115   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1116 
1117   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1118                                 PatternRewriter &rewriter) const override {
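    // If any operand is a constant index, bail and let
    // InsertSliceOpConstantArgumentFolder canonicalize the op first.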
1119     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1120           return matchPattern(operand, matchConstantIndex());
1121         }))
1122       return failure();
1123 
1124     auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1125       auto castOp = v.getDefiningOp<tensor::CastOp>();
1126       if (!castOp || !canFoldIntoConsumerOp(castOp))
1127         return llvm::None;
1128       return castOp.source();
1129     };
1130     Optional<Value> sourceCastSource =
1131         getSourceOfCastOp(insertSliceOp.source());
1132     Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
1133     if (!sourceCastSource && !destCastSource)
1134       return failure();
1135 
1136     Value replacement = rewriter.create<InsertSliceOp>(
1137         insertSliceOp.getLoc(),
1138         (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
1139         (destCastSource ? *destCastSource : insertSliceOp.dest()),
1140         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1141         insertSliceOp.getMixedStrides());
1142 
1143     if (replacement.getType() != insertSliceOp.getType()) {
1144       replacement = rewriter.create<tensor::CastOp>(
1145           insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
1146     }
1147     rewriter.replaceOp(insertSliceOp, replacement);
1148     return success();
1149   }
1150 };
1151 
/// If additional static type information can be deduced from an insert_slice's
1153 /// size operands, insert an explicit cast of the op's source operand. This
1154 /// enables other canonicalization patterns that are matching for tensor_cast
1155 /// ops such as `ForOpTensorCastFolder` in SCF.
1156 ///
1157 /// Example:
1158 ///
1159 /// ```mlir
1160 ///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1161 ///       : tensor<?x?xf32> into ...
1162 /// ```
1163 ///
1164 /// folds into:
1165 ///
1166 /// ```mlir
1167 ///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1168 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1169 ///       : tensor<64x64xf32> into ...
1170 /// ```
1171 struct InsertSliceOpSourceCastInserter final
1172     : public OpRewritePattern<InsertSliceOp> {
1173   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1174 
1175   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1176                                 PatternRewriter &rewriter) const override {
1177     RankedTensorType srcType = insertSliceOp.getSourceType();
1178     if (srcType.getRank() != insertSliceOp.getType().getRank())
1179       return failure();
1180     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1181                                      srcType.getShape().end());
1182     for (int64_t i = 0; i < srcType.getRank(); ++i) {
1183       if (Optional<int64_t> constInt =
1184               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1185         newSrcShape[i] = *constInt;
1186     }
1187     RankedTensorType newSrcType =
1188         RankedTensorType::get(newSrcShape, srcType.getElementType());
1189     if (srcType == newSrcType)
1190       return failure();
1191 
1192     // srcType and newSrcType are different. Insert a cast.
1193     Value cast = rewriter.create<tensor::CastOp>(
1194         insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
1195     rewriter.replaceOpWithNewOp<InsertSliceOp>(
1196         insertSliceOp, cast, insertSliceOp.dest(),
1197         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1198         insertSliceOp.getMixedStrides());
1199     return success();
1200   }
1201 };
1202 } // namespace
1203 
1204 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1205                                                 MLIRContext *context) {
1206   results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
1207               InsertSliceOpSourceCastInserter>(context);
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // TableGen'd op method definitions
1212 //===----------------------------------------------------------------------===//
1213 
1214 #define GET_OP_CLASSES
1215 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
1216