//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
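///
/// For example (illustrative, following the checks below):
///   preservesStaticInformation(tensor<8x?xf32>, tensor<8x16xf32>) is true:
///   every static size in `source` remains static in `target`.
///   preservesStaticInformation(tensor<8x16xf32>, tensor<8x?xf32>) is false:
///   the static size 16 of `source` becomes dynamic in `target`.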
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = source.dyn_cast<RankedTensorType>();
  auto targetType = target.dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // If a dimension is static in `source` but dynamic in `target`, static
  // information is lost.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized to
/// preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has at least as much static information as the result
///    type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of the cast has at least as much static information
  // as its result.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.source().getType());
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
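///
/// For example (illustrative):
///   joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) is tensor<4x8xf32>;
///   joinShapes(tensor<*xf32>, tensor<4xf32>) is tensor<4xf32>;
///   joinShapes(tensor<4xf32>, tensor<8xf32>) has no join, since the static
///   sizes conflict.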
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
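///
/// For example (illustrative), the chain
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
/// collapses into
/// ```mlir
///   %2 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x8xf32>
/// ```
/// whereas a chain casting tensor<?x?xf32> through tensor<4x4xf32> back to
/// tensor<?x?xf32> is left alone: dropping the intermediate cast would drop
/// its runtime shape check.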
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume an unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that a constant index is not known to be out of range.
  auto type = op.source().getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedTensorType>()) {
    // Assume the index to be in range.
  } else {
    llvm_unreachable("expected operand with tensor type");
  }
  return success();
}

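/// Fold a dim op with a constant index when the corresponding extent is
/// statically known, e.g. (illustrative):
/// ```mlir
///   %c0 = constant 0 : index
///   %d = tensor.dim %t, %c0 : tensor<4x?xf32>
/// ```
/// folds to a constant 4, while the same query along dimension 1 does not
/// fold. A dim of a tensor.generate result instead folds to the matching
/// dynamic extent operand, as implemented below.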
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = source().getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = source().getDefiningOp();

  // Fold dim to the corresponding dynamic extent operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the tensor.generate that corresponds to this index.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on the `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
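///
/// For example (illustrative):
/// ```mlir
///   %0 = tensor.cast %t : tensor<4x4xf32> to tensor<?x?xf32>
///   %d = tensor.dim %0, %i : tensor<?x?xf32>
/// ```
/// is rewritten into:
/// ```mlir
///   %d = tensor.dim %t, %i : tensor<4x4xf32>
/// ```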
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

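/// Fold extraction from a constant tensor, e.g. (illustrative):
/// ```mlir
///   %t = constant dense<[1, 2, 3]> : tensor<3xi32>
///   %c1 = constant 1 : index
///   %v = tensor.extract %t[%c1] : tensor<3xi32>
/// ```
/// folds to a constant 2 : i32. Extraction from a splat constant folds even
/// when the indices are not constant, since all elements are equal.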
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute index : llvm::drop_begin(operands, 1)) {
    if (!index || !index.isa<IntegerAttr>())
      return {};
    indices.push_back(index.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
  return {};
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out of bounds accesses. This can happen in invalid code that will
    // never execute.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(InsertOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices");
  return success();
}

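/// Inserting a scalar into a splat constant of that same scalar is a no-op,
/// e.g. (illustrative):
/// ```mlir
///   %cst = constant dense<7> : tensor<4xi32>
///   %c7 = constant 7 : i32
///   %0 = tensor.insert %c7 into %cst[%i] : tensor<4xi32>
/// ```
/// folds to %cst, regardless of the index.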
OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
  Attribute scalar = operands[0];
  Attribute dest = operands[1];
  if (scalar && dest)
    if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
      if (scalar == splatDest.getSplatValue())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with constant extent operands by
/// folding those extents into the result type. A tensor.cast is inserted to
/// keep the resulting IR well-typed.
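///
/// For example (illustrative), with %c16 = constant 16 : index:
/// ```mlir
///   %0 = tensor.generate %c16, %d { ... } : tensor<?x?xf32>
/// ```
/// becomes
/// ```mlir
///   %1 = tensor.generate %d { ... } : tensor<16x?xf32>
///   %0 = tensor.cast %1 : tensor<16x?xf32> to tensor<?x?xf32>
/// ```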
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   tensor.yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

static LogicalResult verify(ReshapeOp op) {
  TensorType operandType = op.source().getType().cast<TensorType>();
  TensorType resultType = op.result().getType().cast<TensorType>();

  if (operandType.getElementType() != resultType.getElementType())
    return op.emitOpError("element types of source and destination tensor "
                          "types should be the same");

  int64_t shapeSize =
      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
        return op.emitOpError("source and destination tensor should have the "
                              "same number of elements");
    }
    if (shapeSize == TensorType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
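///
/// For example (illustrative), with a tensor<8x16x4xf32> source and static
/// sizes [4, kDynamicSize, 2], the inferred result type is tensor<4x?x2xf32>.
/// Only the sizes participate in the inference; the offsets and strides do
/// not.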
Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                     ArrayRef<int64_t> leadingStaticOffsets,
                                     ArrayRef<int64_t> leadingStaticSizes,
                                     ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offset/sizes/
  // strides in which case we complete with offset=0, sizes from the source
  // tensor type and strides=1.
  unsigned rank = sourceRankedTensorType.getRank();
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  unsigned numTrailingSizes = rank - staticSizes.size();
  llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
                                      numTrailingSizes));
  return RankedTensorType::get(staticSizes,
                               sourceRankedTensorType.getElementType());
}

Type ExtractSliceOp::inferResultType(
    RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}

/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case. The rank-reduced variant additionally
/// drops unit dimensions until `resultRank` is reached.
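///
/// For example (illustrative), with static sizes [1, 4, 1, 2] and
/// resultRank = 2, the rank-4 inferred type tensor<1x4x1x2xf32> is reduced to
/// tensor<4x2xf32> by projecting out the first two unit dimensions
/// (positions 0 and 2).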
Type ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<RankedTensorType>();
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

Type ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType =
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                        staticSizes, staticStrides)
            .cast<RankedTensorType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with dynamic entries and custom result type. If the
/// type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

enum SliceVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
};

/// Checks if the `originalType` type can be rank-reduced to the
/// `candidateReducedType` type. This function is a slight variant of the
/// `is subsequence` check, where every non-matching dimension must be 1.
static SliceVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;
  if (!originalType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;
  if (!candidateReducedType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // If no rank-reduction mask can be computed, the sizes do not match.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          OpTy op, Type expectedType,
                                          StringRef errMsg = "") {
  auto shapedType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << shapedType.getElementType() << errMsg;
  }
  llvm_unreachable("unexpected extract_slice op verification result");
}

/// Verifier for ExtractSliceOp.
static LogicalResult verify(ExtractSliceOp op) {
  // Verify the result type against the inferred type.
  auto expectedType = ExtractSliceOp::inferResultType(
      op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));
  auto result = isRankReducedType(expectedType, op.getType());
  return produceSliceErrorMsg(result, op, expectedType);
}

/// Infer the canonical type of the result of an extract_slice op. Returns the
/// rank-reduced type with rank `resultRank` if one can be inferred, and the
/// non-rank-reduced type otherwise.
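///
/// For example (illustrative), with a tensor<8x16x4xf32> source and sizes
/// [1, 16, 4], resultRank = 2 yields tensor<16x4xf32>, while resultRank = 3
/// yields the non-rank-reduced tensor<1x16x4xf32>.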
static RankedTensorType
getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
                            ArrayRef<OpFoldResult> mixedOffsets,
                            ArrayRef<OpFoldResult> mixedSizes,
                            ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      ExtractSliceOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<RankedTensorType>();
  if (resultType.getRank() != resultRank) {
    resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
                                                 mixedSizes, mixedStrides)
                     .cast<RankedTensorType>();
  }
  return resultType;
}

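/// Return the set of source dimensions that are dropped by this rank-reducing
/// extract_slice, i.e. unit slice sizes whose dimension does not appear in
/// the result type. For example (illustrative), an op with sizes [1, %sz, 1]
/// and result type tensor<?x1xf32> drops source dimension 0 but keeps source
/// dimension 2, whose static 1 is still present in the result shape.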
llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
  llvm::SmallDenseSet<unsigned> droppedDims;
  ArrayRef<int64_t> resultShape = getType().getShape();
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  unsigned shapePos = 0;
  for (auto size : enumerate(mixedSizes)) {
    Optional<int64_t> sizeVal = getConstantIntValue(size.value());
    // A dimension is preserved if its size is not a static 1, or if the
    // currently matched result dimension is itself a static 1 (in which case
    // the unit dimension was kept rather than dropped).
    if (!sizeVal || sizeVal.getValue() != 1 ||
        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
      shapePos++;
      continue;
    }
    droppedDims.insert(size.index());
  }
  return droppedDims;
}

LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
  Location loc = getLoc();
  for (auto size : enumerate(mixedSizes)) {
    if (droppedDims.count(size.index()))
      continue;
    if (auto attr = size.value().dyn_cast<Attribute>()) {
      reifiedReturnShapes[0].push_back(builder.create<ConstantIndexOp>(
          loc, attr.cast<IntegerAttr>().getInt()));
      continue;
    }
    reifiedReturnShapes[0].push_back(size.value().get<Value>());
  }
  return success();
}

namespace {
/// Pattern to rewrite an extract_slice op with tensor::CastOp arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
///   tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///   tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail and let the constant argument folder
    // pattern kick in first.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // Deduce the type of the result to use for the canonicalized operation.
    RankedTensorType resultType = getCanonicalSliceResultType(
        sliceOp.getType().getRank(), sliceOp.getSourceType(),
        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
        sliceOp.getMixedStrides());
    Value newSlice = rewriter.create<ExtractSliceOp>(
        sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
        sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
        sliceOp.static_sizes(), sliceOp.static_strides());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
                                                newSlice);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSliceResultType(op.getType().getRank(),
                                       op.getSourceType(), mixedOffsets,
                                       mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace ExtractSliceOps.
struct SliceCanonicalizer {
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }
};

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}

/// Returns success if `op` is a no-op slice of `shapedType`, i.e. all offsets
/// are 0, all strides are 1, and the sizes match the shape of `shapedType`.
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
      return failure();
  // Rank-reducing noops only need to inspect the leading dimensions:
  // llvm::zip is appropriate.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
      return failure();
  return success();
}

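/// An extract_slice that covers its entire source with unit strides is a
/// no-op and folds to its source, e.g. (illustrative):
/// ```mlir
///   %0 = tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
///       : tensor<4x8xf32> to tensor<4x8xf32>
/// ```
/// folds to %t.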
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          ArrayRef<OpFoldResult> strides,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

// An insert_slice whose source covers its entire destination with unit
// strides is a no-op and folds to its source.
OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  return OpFoldResult();
}

LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    reifiedReturnShapes[0][dim] =
        builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
  }
  return success();
}

namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertSliceOp> {
public:
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing ones.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
        mixedOffsets, mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold tensor.cast ops into insert_slice operations. If the source or
/// destination tensor is a tensor.cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast back to ensure that the type of the result
/// does not change.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail and let the constant argument folder
    // pattern kick in first.
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return llvm::None;
      return castOp.source();
    };
    Optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.source());
    Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    Value replacement = rewriter.create<InsertSliceOp>(
        insertSliceOp.getLoc(),
        (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
        (destCastSource ? *destCastSource : insertSliceOp.dest()),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

    if (replacement.getType() != insertSliceOp.getType()) {
      replacement = rewriter.create<tensor::CastOp>(
          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
    }
    rewriter.replaceOp(insertSliceOp, replacement);
    return success();
  }
};

/// If additional static type information can be deduced from an
/// insert_slice's size operands, insert an explicit cast of the op's source
/// operand. This enables other canonicalization patterns that are matching
/// for tensor.cast ops such as `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
///       : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
///       : tensor<64x64xf32> into ...
/// ```
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                     srcType.getShape().end());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (Optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
        newSrcShape[i] = *constInt;
    }
    RankedTensorType newSrcType =
        RankedTensorType::get(newSrcShape, srcType.getElementType());
    if (srcType == newSrcType)
      return failure();

    // srcType and newSrcType are different. Insert a cast.
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, cast, insertSliceOp.dest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
              InsertSliceOpSourceCastInserter>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"