//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted alongside `slice` ops and are
/// canonicalized away to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank;
/// 2. the source tensor type has more static information than the result type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

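/// Two tensor types are cast-compatible when they share an element type and
/// have dimension-wise compatible shapes. As an illustrative sketch (types
/// are made up for the example): `tensor<?x4xf32>` and `tensor<8x?xf32>` are
/// compatible, while `tensor<8x4xf32>` and `tensor<8x5xf32>` are not.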
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match. Returns a null type
/// when the shapes cannot be joined (rank or static dimension mismatch).
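/// As an illustrative example: joining `tensor<?x8xf32>` with
/// `tensor<4x?xf32>` yields `tensor<4x8xf32>`, while joining
/// `tensor<4x8xf32>` with `tensor<4x9xf32>` fails and yields a null type.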
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
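///
/// An illustrative sketch (shapes are made up for the example):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
///   %2 = tensor.cast %1 : tensor<4x?xf32> to tensor<4x8xf32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```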
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}


static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with tensor type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = source().getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = source().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto generateOp = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        generateOp.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the tensor.generate that corresponds to this index.
    auto dynExtents = generateOp.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    assert(sliceOp.isDynamicSize(unsignedIndex) &&
           "Expected dynamic slice size");
    return sliceOp.getDynamicSize(unsignedIndex);
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
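///
/// An illustrative sketch (assuming the `tensor.dim` spelling of DimOp):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
///   %d = tensor.dim %0, %c1 : tensor<?x?xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
///   %d = tensor.dim %t, %c1 : tensor<4x?xf32>
/// ```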
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indexAttr : llvm::drop_begin(operands, 1)) {
    if (!indexAttr || !indexAttr.isa<IntegerAttr>())
      return {};
    indices.push_back(indexAttr.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

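/// The fold below succeeds only when every element is itself a known
/// constant. An illustrative sketch, assuming std-dialect index constants:
///
/// ```mlir
///   %c0 = constant 0 : index
///   %c1 = constant 1 : index
///   %t = tensor.from_elements %c0, %c1 : tensor<2xindex>
/// ```
///
/// folds to the attribute `dense<[0, 1]> : tensor<2xindex>`.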
OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
  return {};
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements %element : tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out of bounds accesses. This can happen in invalid code that will
    // never execute.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(InsertOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
  Attribute scalar = operands[0];
  Attribute dest = operands[1];
  if (scalar && dest)
    if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
      if (scalar == splatDest.getSplatValue())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are specified
  // by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

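/// A usage sketch for the body-builder overload below (all names are
/// illustrative). Note that `bodyBuilder` is expected to create the
/// terminating `tensor.yield` itself:
///
/// ```c++
///   Value gen = b.create<GenerateOp>(
///       loc, RankedTensorType::get({ShapedType::kDynamicSize}, elemTy),
///       /*dynamicExtents=*/ValueRange{n},
///       [&](OpBuilder &nb, Location nloc, ValueRange indices) {
///         // Compute the element at `indices` and yield it.
///         nb.create<tensor::YieldOp>(nloc, someElement);
///       });
/// ```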
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with constant extent operands
/// into the equivalent operation with those extents reflected in the result
/// type. A tensor.cast back to the original type is inserted to keep the
/// resulting IR well-typed.
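///
/// An illustrative sketch, assuming std-dialect index constants:
///
/// ```mlir
///   %size = constant 10 : index
///   %0 = tensor.generate %size {
///     ...
///   } : tensor<?xindex>
/// ```
///
/// becomes:
///
/// ```mlir
///   %0 = tensor.generate {
///     ...
///   } : tensor<10xindex>
///   %1 = tensor.cast %0 : tensor<10xindex> to tensor<?xindex>
/// ```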
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        generateOp.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = generateOp.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == generateOp.dynamicExtents().size())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(generateOp.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto generateOp = extract.tensor().getDefiningOp<GenerateOp>();
    if (!generateOp || !wouldOpBeTriviallyDead(generateOp))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = generateOp.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

/// Returns the product of all dimension sizes. Callers must ensure that
/// `type` has a static shape.
static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

static LogicalResult verify(ReshapeOp op) {
  TensorType operandType = op.source().getType().cast<TensorType>();
  TensorType resultType = op.result().getType().cast<TensorType>();

  if (operandType.getElementType() != resultType.getElementType())
    return op.emitOpError("element types of source and destination tensor "
                          "types should be the same");

  int64_t shapeSize =
      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return op.emitOpError("source and destination tensor should have the "
                              "same number of elements");
    }
    if (shapeSize == TensorType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
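/// As an illustrative sketch: for a `tensor<8x16x4xf32>` source, static sizes
/// [4, kDynamicSize, 4] infer the result type `tensor<4x?x4xf32>`, and a
/// leading subset such as [4] is completed with the remaining source sizes to
/// yield `tensor<4x16x4xf32>`.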
Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                     ArrayRef<int64_t> leadingStaticOffsets,
                                     ArrayRef<int64_t> leadingStaticSizes,
                                     ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides in which case we complete with offset=0, sizes from the source
  // tensor type and strides=1.
  unsigned rank = sourceRankedTensorType.getRank();
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  unsigned numTrailingSizes = rank - staticSizes.size();
  llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
                                      numTrailingSizes));
  return RankedTensorType::get(staticSizes,
                               sourceRankedTensorType.getElementType());
}

Type ExtractSliceOp::inferResultType(
    RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}

/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case. This rank-reducing variant additionally
/// projects out size-1 dimensions until the result type reaches `resultRank`.
Type ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<RankedTensorType>();
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

Type ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType =
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                        staticSizes, staticStrides)
            .cast<RankedTensorType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with dynamic entries and custom result type. If the
/// type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

enum SliceVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
};

/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is a subsequence` algorithm,
/// where each non-matching dimension must be 1.
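/// For example (illustrative): `tensor<1x4x1x8xf32>` can be rank-reduced to
/// `tensor<4x8xf32>` or `tensor<4x1x8xf32>`, but not to `tensor<4x4xf32>`
/// (size mismatch) or to any type of larger rank.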
static SliceVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;
  if (!originalType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;
  if (!candidateReducedType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask can be computed.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          OpTy op, Type expectedType,
                                          StringRef errMsg = "") {
  auto shapedType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << shapedType.getElementType() << errMsg;
  }
  llvm_unreachable("unexpected extract_slice op verification result");
}

/// Verifier for ExtractSliceOp.
static LogicalResult verify(ExtractSliceOp op) {
  // Verify result type against inferred type.
  auto expectedType = ExtractSliceOp::inferResultType(
      op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));
  auto result = isRankReducedType(expectedType, op.getType());
  return produceSliceErrorMsg(result, op, expectedType);
}

/// Infer the canonical type of the result of an extract_slice op. Returns the
/// rank-reduced type when it has rank `resultRank`, and the non-rank-reduced
/// type otherwise.
static RankedTensorType
getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
                            ArrayRef<OpFoldResult> mixedOffsets,
                            ArrayRef<OpFoldResult> mixedSizes,
                            ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      ExtractSliceOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<RankedTensorType>();
  if (resultType.getRank() != resultRank) {
    resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
                                                 mixedSizes, mixedStrides)
                     .cast<RankedTensorType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite an extract_slice op with tensor::CastOp arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] :
///        tensor<?x?xf32> to tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] :
///        tensor<16x16xf32> to tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is constant, bail out and let the constant-argument
    // folder canonicalize the op first.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // Deduce the type of the result to use for the canonicalized operation.
    RankedTensorType resultType = getCanonicalSliceResultType(
        sliceOp.getType().getRank(), sliceOp.getSourceType(),
        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
        sliceOp.getMixedStrides());
    Value newSlice = rewriter.create<ExtractSliceOp>(
        sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
        sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
        sliceOp.static_sizes(), sliceOp.static_strides());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
                                                newSlice);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSliceResultType(op.getType().getRank(),
                                       op.getSourceType(), mixedOffsets,
                                       mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace ExtractSliceOps.
struct SliceCanonicalizer {
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }
};

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}

/// Return success when `op` is a no-op slice of `shapedType`: all offsets
/// are 0, all strides are 1, and the sizes match the shape of `shapedType`.
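/// As a sketch, `tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
/// : tensor<4x8xf32> to tensor<4x8xf32>` is such an identity slice and folds
/// to `%t`.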
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
      return failure();
  // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
  // is appropriate.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
      return failure();
  return success();
}

OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          ArrayRef<OpFoldResult> strides,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  return OpFoldResult();
}

LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    reifiedReturnShapes[0][dim] =
        builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
  }
  return success();
}

namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
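///
/// An illustrative sketch, assuming a std-dialect index constant:
///
/// ```mlir
///   %c0 = constant 0 : index
///   %r = tensor.insert_slice %s into %d[%c0, %c0] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<8x8xf32>
/// ```
///
/// is rewritten so the offsets become static attributes:
///
/// ```mlir
///   %r = tensor.insert_slice %s into %d[0, 0] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<8x8xf32>
/// ```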
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertSliceOp> {
public:
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
        mixedOffsets, mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold tensor.cast ops into insert_slice operations. If the source or
/// destination tensor is a tensor.cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast back so that the type of the result does
/// not change.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return llvm::None;
      return castOp.source();
    };
    Optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.source());
    Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    Value replacement = rewriter.create<InsertSliceOp>(
        insertSliceOp.getLoc(),
        (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
        (destCastSource ? *destCastSource : insertSliceOp.dest()),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

    if (replacement.getType() != insertSliceOp.getType()) {
      replacement = rewriter.create<tensor::CastOp>(
          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
    }
    rewriter.replaceOp(insertSliceOp, replacement);
    return success();
  }
};

/// If additional static type information can be deduced from an insert_slice
/// op's size operands, insert an explicit cast of the op's source operand.
/// This enables other canonicalization patterns that match on tensor.cast
/// ops, such as `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
///       : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
///       : tensor<64x64xf32> into ...
/// ```
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                     srcType.getShape().end());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (Optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
        newSrcShape[i] = *constInt;
    }
    RankedTensorType newSrcType =
        RankedTensorType::get(newSrcShape, srcType.getElementType());
    if (srcType == newSrcType)
      return failure();

    // srcType and newSrcType are different. Insert a cast.
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, cast, insertSliceOp.dest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
              InsertSliceOpSourceCastInserter>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"