1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/Dialect/Utils/StaticValueUtils.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::tensor;
21 
22 //===----------------------------------------------------------------------===//
23 // CastOp
24 //===----------------------------------------------------------------------===//
25 
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects
/// that may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted as `slice` ops and are
/// canonicalized to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
36 ///
37 /// Example:
38 /// ```mlir
39 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
40 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
41 /// ```
42 ///
43 /// folds into:
44 ///
45 /// ```mlir
46 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
47 /// ```
48 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
49   if (!castOp)
50     return false;
51 
52   RankedTensorType sourceType =
53       castOp.source().getType().dyn_cast<RankedTensorType>();
54   RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
55 
56   // Requires RankedTensorType.
57   if (!sourceType || !resultType)
58     return false;
59 
  // Requires same element type.
61   if (sourceType.getElementType() != resultType.getElementType())
62     return false;
63 
64   // Requires same rank.
65   if (sourceType.getRank() != resultType.getRank())
66     return false;
67 
68   // If cast is towards more static sizes along any dimension, don't fold.
69   for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
70     if (ShapedType::isDynamic(std::get<0>(t)) &&
71         !ShapedType::isDynamic(std::get<1>(t)))
72       return false;
73   }
74 
75   return true;
76 }
77 
78 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
79 /// that can be folded.
80 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
81   bool folded = false;
82   for (OpOperand &operand : op->getOpOperands()) {
83     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
84     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
85       operand.set(castOp.getOperand());
86       folded = true;
87     }
88   }
89   return success(folded);
90 }
91 
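/// Two types are cast-compatible if both are tensors with the same element
/// type and compatible shapes: each pair of static extents must match, while
/// dynamic dimensions and unranked tensors act as wildcards. For example
/// (illustrative):
///
/// ```mlir
///   %0 = tensor.cast %a : tensor<4x?xf32> to tensor<?x8xf32> // compatible
///   %1 = tensor.cast %b : tensor<4xf32> to tensor<*xf32>     // compatible
/// ```
///
/// A cast between tensor<4xf32> and tensor<5xf32> is not compatible.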
92 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
93   if (inputs.size() != 1 || outputs.size() != 1)
94     return false;
95   Type a = inputs.front(), b = outputs.front();
96   auto aT = a.dyn_cast<TensorType>();
97   auto bT = b.dyn_cast<TensorType>();
98   if (!aT || !bT)
99     return false;
100 
101   if (aT.getElementType() != bT.getElementType())
102     return false;
103 
104   return succeeded(verifyCompatibleShape(aT, bT));
105 }
106 
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match. Returns a null type
/// if the two shapes cannot be joined (different ranks or conflicting static
/// dimensions).
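///
/// For example (illustrative):
///   join(tensor<?x8xf32>, tensor<2x?xf32>) = tensor<2x8xf32>
///   join(tensor<*xf32>, tensor<2x?xf32>)   = tensor<2x?xf32>
///   join(tensor<2x8xf32>, tensor<2x9xf32>) = null (conflicting extents)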
109 static TensorType joinShapes(TensorType one, TensorType two) {
110   assert(one.getElementType() == two.getElementType());
111 
112   if (!one.hasRank())
113     return two;
114   if (!two.hasRank())
115     return one;
116 
117   int64_t rank = one.getRank();
118   if (rank != two.getRank())
119     return {};
120 
121   SmallVector<int64_t, 4> join;
122   join.reserve(rank);
123   for (int64_t i = 0; i < rank; ++i) {
124     if (one.isDynamicDim(i)) {
125       join.push_back(two.getDimSize(i));
126       continue;
127     }
128     if (two.isDynamicDim(i)) {
129       join.push_back(one.getDimSize(i));
130       continue;
131     }
132     if (one.getDimSize(i) != two.getDimSize(i))
133       return {};
134     join.push_back(one.getDimSize(i));
135   }
136   return RankedTensorType::get(join, one.getElementType());
137 }
138 
139 namespace {
140 
141 /// Replaces chains of two tensor.cast operations by a single tensor.cast
142 /// operation if doing so does not remove runtime constraints.
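///
/// For example (illustrative):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>
/// ```
///
/// By contrast, casting tensor<?x?xf32> down to tensor<4x?xf32> and back up
/// is not folded, since dropping the intermediate cast would drop a runtime
/// shape check.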
143 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
144   using OpRewritePattern<CastOp>::OpRewritePattern;
145 
146   LogicalResult matchAndRewrite(CastOp tensorCast,
147                                 PatternRewriter &rewriter) const final {
148     auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
149 
150     if (!tensorCastOperand)
151       return failure();
152 
153     auto sourceType =
154         tensorCastOperand.getOperand().getType().cast<TensorType>();
155     auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
156     auto resultType = tensorCast.getType().cast<TensorType>();
157 
158     // We can remove the intermediate cast if joining all three produces the
159     // same result as just joining the source and result shapes.
160     auto firstJoin =
161         joinShapes(joinShapes(sourceType, intermediateType), resultType);
162 
163     // The join might not exist if the cast sequence would fail at runtime.
164     if (!firstJoin)
165       return failure();
166 
    // The newJoin always exists if the above join exists, but it might
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
170     auto newJoin = joinShapes(sourceType, resultType);
171     if (firstJoin != newJoin)
172       return failure();
173 
174     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
175                                         tensorCastOperand.getOperand());
176     return success();
177   }
178 };
179 
180 } // namespace
181 
182 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
183                                          MLIRContext *context) {
184   results.add<ChainedTensorCast>(context);
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // ExtractOp
189 //===----------------------------------------------------------------------===//
190 
191 static LogicalResult verify(ExtractOp op) {
192   // Verify the # indices match if we have a ranked type.
193   if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
194     if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
195       return op.emitOpError("incorrect number of indices for extract_element");
196 
197   return success();
198 }
199 
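/// Folds an extract from a constant tensor. For a splat constant the result
/// is the splat value regardless of the index; e.g. (illustrative, assuming
/// %cst is a splat constant of value 1.0 and type tensor<4xf32>):
///
/// ```mlir
///   %0 = tensor.extract %cst[%i] : tensor<4xf32>
/// ```
///
/// folds to the f32 constant 1.0 for any index %i. For a non-splat constant
/// tensor, all indices must themselves be constants for the fold to apply.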
200 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
201   // The tensor operand must be a known constant.
202   Attribute tensor = operands.front();
203   if (!tensor)
204     return {};
205   // If this is a splat elements attribute, simply return the value. All of the
206   // elements of a splat attribute are the same.
207   if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
208     return splatTensor.getSplatValue();
209 
210   // Otherwise, collect the constant indices into the tensor.
211   SmallVector<uint64_t, 8> indices;
  for (Attribute index : llvm::drop_begin(operands, 1)) {
    if (!index || !index.isa<IntegerAttr>())
      return {};
    indices.push_back(index.cast<IntegerAttr>().getInt());
  }
217 
218   // If this is an elements attribute, query the value at the given indices.
219   auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
220   if (elementsAttr && elementsAttr.isValidIndex(indices))
221     return elementsAttr.getValue(indices);
222   return {};
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // FromElementsOp
227 //===----------------------------------------------------------------------===//
228 
229 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
230                            Type elementType, ValueRange elements) {
231   Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
232                                         elementType);
233   result.addOperands(elements);
234   result.addTypes(resultTy);
235 }
236 
237 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
238                            ValueRange elements) {
239   assert(!elements.empty() && "expected at least one element");
240   build(builder, result, elements.front().getType(), elements);
241 }
242 
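/// Folds to a constant when every element operand is itself a constant. For
/// example (illustrative, assuming %c0 and %c1 are constant index values 0
/// and 1):
///
/// ```mlir
///   %t = tensor.from_elements %c0, %c1 : tensor<2xindex>
/// ```
///
/// folds to the constant dense<[0, 1]> : tensor<2xindex>.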
243 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
244   if (!llvm::is_contained(operands, nullptr))
245     return DenseElementsAttr::get(getType(), operands);
246   return {};
247 }
248 
249 namespace {
250 
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
///
/// to just %element.
257 struct ExtractElementFromTensorFromElements
258     : public OpRewritePattern<tensor::ExtractOp> {
259   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
260 
261   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
262                                 PatternRewriter &rewriter) const final {
263     if (extract.indices().size() != 1)
264       return failure();
265 
266     auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
267     if (tensorFromElements == nullptr)
268       return failure();
269 
270     APInt index;
271     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
272       return failure();
273     // Prevent out of bounds accesses. This can happen in invalid code that will
274     // never execute.
275     if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
276         index.getSExtValue() < 0)
277       return failure();
278     rewriter.replaceOp(extract,
279                        tensorFromElements.getOperand(index.getZExtValue()));
280     return success();
281   }
282 };
283 
284 } // namespace
285 
286 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
287                                                  MLIRContext *context) {
288   results.add<ExtractElementFromTensorFromElements>(context);
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // InsertOp
293 //===----------------------------------------------------------------------===//
294 
295 static LogicalResult verify(InsertOp op) {
296   // Verify the # indices match if we have a ranked type.
297   if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
298     if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
299       return op.emitOpError("incorrect number of indices");
300   return success();
301 }
302 
303 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
304   Attribute scalar = operands[0];
305   Attribute dest = operands[1];
306   if (scalar && dest)
307     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
308       if (scalar == splatDest.getSplatValue())
309         return dest;
310   return {};
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // GenerateOp
315 //===----------------------------------------------------------------------===//
316 
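/// A well-formed tensor.generate op, for reference (illustrative): one index
/// operand per dynamic extent, one index block argument per result dimension,
/// and a yield of the element type.
///
/// ```mlir
///   %t = tensor.generate %m, %n {
///   ^bb0(%i : index, %j : index):
///     <computation>
///     yield %elem : f32
///   } : tensor<?x?xf32>
/// ```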
317 static LogicalResult verify(GenerateOp op) {
318   // Ensure that the tensor type has as many dynamic dimensions as are specified
319   // by the operands.
320   RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
321   if (op.getNumOperands() != resultTy.getNumDynamicDims())
322     return op.emitError("must have as many index operands as dynamic extents "
323                         "in the result type");
324 
325   // Ensure that region arguments span the index space.
326   if (!llvm::all_of(op.body().getArgumentTypes(),
327                     [](Type ty) { return ty.isIndex(); }))
328     return op.emitError("all body arguments must be index");
329   if (op.body().getNumArguments() != resultTy.getRank())
330     return op.emitError("must have one body argument per input dimension");
331 
332   // Ensure that the region yields an element of the right type.
333   auto yieldOp =
334       llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
335   if (yieldOp.value().getType() != resultTy.getElementType())
336     return op.emitOpError(
337         "body must be terminated with a `yield` operation of the tensor "
338         "element type");
339 
340   return success();
341 }
342 
343 void GenerateOp::build(
344     OpBuilder &b, OperationState &result, Type resultTy,
345     ValueRange dynamicExtents,
346     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
347   build(b, result, resultTy, dynamicExtents);
348 
349   // Build and populate body.
350   OpBuilder::InsertionGuard guard(b);
351   Region *bodyRegion = result.regions.front().get();
352   auto rank = resultTy.cast<RankedTensorType>().getRank();
353   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
354   Block *bodyBlock =
355       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
356   bodyBuilder(b, result.location, bodyBlock->getArguments());
357 }
358 
359 namespace {
360 
/// Canonicalizes tensor.generate operations with constant extent operands
/// into an equivalent operation whose result type carries those extents
/// statically. A tensor.cast back to the original result type is inserted so
/// that the resulting IR remains well-typed.
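///
/// For example (illustrative, assuming %c3 is a constant index of value 3):
///
/// ```mlir
///   %0 = tensor.generate %c3 {
///   ^bb0(%i : index):
///     <computation>
///     yield %elem : f32
///   } : tensor<?xf32>
/// ```
///
/// becomes a tensor.generate of type tensor<3xf32>, followed by a tensor.cast
/// from tensor<3xf32> back to the original tensor<?xf32>.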
365 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
366   using OpRewritePattern<GenerateOp>::OpRewritePattern;
367 
  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        generateOp.getResult().getType().cast<RankedTensorType>();
372 
373     if (resultType.hasStaticShape())
374       return failure();
375 
376     SmallVector<Value, 4> newOperands;
377     SmallVector<int64_t, 4> newShape;
    auto operandsIt = generateOp.dynamicExtents().begin();
379 
380     for (int64_t dim : resultType.getShape()) {
381       if (dim != RankedTensorType::kDynamicSize) {
382         newShape.push_back(dim);
383         continue;
384       }
385       APInt index;
386       if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
387         newShape.push_back(RankedTensorType::kDynamicSize);
388         newOperands.push_back(*operandsIt++);
389         continue;
390       }
391       newShape.push_back(index.getSExtValue());
392       operandsIt++;
393     }
394 
    if (newOperands.size() == generateOp.dynamicExtents().size())
396       return failure();
397 
    auto loc = generateOp.getLoc();
399     auto newOp = rewriter.create<GenerateOp>(
400         loc, RankedTensorType::get(newShape, resultType.getElementType()),
401         newOperands);
    rewriter.inlineRegionBefore(generateOp.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp, resultType,
                                                newOp);
406     return success();
407   }
408 };
409 
410 /// Canonicalizes the pattern of the form
411 ///
412 /// %tensor = tensor.generate %x {
413 ///   ^bb0(%arg0: index):  // no predecessors
414 ///   <computation>
415 ///   yield %1 : index
416 /// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
418 ///
419 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
420 /// tensor.generate operation has no side-effects.
421 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
422   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
423 
424   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
425                                 PatternRewriter &rewriter) const final {
    auto generateOp = extract.tensor().getDefiningOp<GenerateOp>();
    if (!generateOp || !wouldOpBeTriviallyDead(generateOp))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = generateOp.getBody();
    mapping.map(body->getArguments(), extract.indices());
433     for (auto &op : body->without_terminator())
434       rewriter.clone(op, mapping);
435 
436     auto yield = cast<YieldOp>(body->getTerminator());
437 
438     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
439     return success();
440   }
441 };
442 
443 /// Canonicalizes the pattern of the form
444 ///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
446 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
447 ///
448 /// to
449 ///
450 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
451 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
452   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
453 
454   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
455                                 PatternRewriter &rewriter) const final {
456     auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
457     if (!tensorCast)
458       return failure();
459 
460     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
461                                                    extract.indices());
462     return success();
463   }
464 };
465 
466 } // namespace
467 
468 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
469                                              MLIRContext *context) {
470   // TODO: Move extract patterns to tensor::ExtractOp.
471   results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
472               StaticTensorGenerate>(context);
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // ReshapeOp
477 //===----------------------------------------------------------------------===//
478 
static int64_t getNumElements(ShapedType type) {
480   int64_t numElements = 1;
481   for (auto dim : type.getShape())
482     numElements *= dim;
483   return numElements;
484 }
485 
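/// For reference, a well-formed reshape of a statically-shaped tensor
/// (illustrative): the element counts match (16) and the shape operand's
/// length (2) equals the result rank.
///
/// ```mlir
///   %dst = tensor.reshape %src(%shape)
///       : (tensor<4x4xf32>, tensor<2xindex>) -> tensor<2x8xf32>
/// ```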
486 static LogicalResult verify(ReshapeOp op) {
487   TensorType operandType = op.source().getType().cast<TensorType>();
488   TensorType resultType = op.result().getType().cast<TensorType>();
489 
490   if (operandType.getElementType() != resultType.getElementType())
491     return op.emitOpError("element types of source and destination tensor "
492                           "types should be the same");
493 
494   int64_t shapeSize =
495       op.shape().getType().cast<RankedTensorType>().getDimSize(0);
496   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
497   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
498 
499   if (resultRankedType) {
500     if (operandRankedType && resultRankedType.hasStaticShape() &&
501         operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
503         return op.emitOpError("source and destination tensor should have the "
504                               "same number of elements");
505     }
506     if (shapeSize == TensorType::kDynamicSize)
507       return op.emitOpError("cannot use shape operand with dynamic length to "
508                             "reshape to statically-ranked tensor type");
509     if (shapeSize != resultRankedType.getRank())
510       return op.emitOpError(
511           "length of shape operand differs from the result's tensor rank");
512   }
513   return success();
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // ExtractSliceOp
518 //===----------------------------------------------------------------------===//
519 
520 /// An extract_slice op result type can be fully inferred from the source type
521 /// and the static representation of offsets, sizes and strides. Special
522 /// sentinels encode the dynamic case.
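///
/// For example (illustrative): given a source tensor<8x16x4xf32> and leading
/// static sizes [2, 4], the missing trailing size is completed from the
/// source type and the inferred result type is tensor<2x4x4xf32>.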
523 Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
524                                      ArrayRef<int64_t> leadingStaticOffsets,
525                                      ArrayRef<int64_t> leadingStaticSizes,
526                                      ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, in which case we complete with offset=0, sizes from the source
  // tensor type, and strides=1.
530   unsigned rank = sourceRankedTensorType.getRank();
531   assert(leadingStaticSizes.size() <= rank &&
532          "unexpected leadingStaticSizes overflow");
533   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
534   unsigned numTrailingSizes = rank - staticSizes.size();
535   llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
536                                       numTrailingSizes));
537   return RankedTensorType::get(staticSizes,
538                                sourceRankedTensorType.getElementType());
539 }
540 
541 Type ExtractSliceOp::inferResultType(
542     RankedTensorType sourceRankedTensorType,
543     ArrayRef<OpFoldResult> leadingStaticOffsets,
544     ArrayRef<OpFoldResult> leadingStaticSizes,
545     ArrayRef<OpFoldResult> leadingStaticStrides) {
546   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
547   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
548   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
549                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
550   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
551                              ShapedType::kDynamicSize);
552   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
553                              staticStrides, ShapedType::kDynamicStrideOrOffset);
554   return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
555                                          staticSizes, staticStrides);
556 }
557 
558 /// An extract_slice op result type can be fully inferred from the source type
559 /// and the static representation of offsets, sizes and strides. Special
560 /// sentinels encode the dynamic case.
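///
/// For example (illustrative): with a source tensor<8x16xf32>, static sizes
/// [1, 16] and `resultRank` 1, the inferred tensor<1x16xf32> is rank-reduced
/// by dropping its unit dimension, yielding tensor<16xf32>.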
561 Type ExtractSliceOp::inferRankReducedResultType(
562     unsigned resultRank, RankedTensorType sourceRankedTensorType,
563     ArrayRef<int64_t> leadingStaticOffsets,
564     ArrayRef<int64_t> leadingStaticSizes,
565     ArrayRef<int64_t> leadingStaticStrides) {
566   auto inferredType =
567       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
568                       leadingStaticSizes, leadingStaticStrides)
569           .cast<RankedTensorType>();
570   int rankDiff = inferredType.getRank() - resultRank;
571   if (rankDiff > 0) {
572     auto shape = inferredType.getShape();
573     llvm::SmallDenseSet<unsigned> dimsToProject;
574     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
575     SmallVector<int64_t> projectedShape;
576     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
577       if (!dimsToProject.contains(pos))
578         projectedShape.push_back(shape[pos]);
579     inferredType =
580         RankedTensorType::get(projectedShape, inferredType.getElementType());
581   }
582   return inferredType;
583 }
584 
585 Type ExtractSliceOp::inferRankReducedResultType(
586     unsigned resultRank, RankedTensorType sourceRankedTensorType,
587     ArrayRef<OpFoldResult> leadingStaticOffsets,
588     ArrayRef<OpFoldResult> leadingStaticSizes,
589     ArrayRef<OpFoldResult> leadingStaticStrides) {
590   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
591   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
592   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
593                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
594   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
595                              ShapedType::kDynamicSize);
596   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
597                              staticStrides, ShapedType::kDynamicStrideOrOffset);
598   return ExtractSliceOp::inferRankReducedResultType(
599       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
600       staticStrides);
601 }
602 
603 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
604 /// result type. If the type passed is nullptr, it is inferred.
605 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
606                            RankedTensorType resultType, Value source,
607                            ArrayRef<OpFoldResult> offsets,
608                            ArrayRef<OpFoldResult> sizes,
609                            ArrayRef<OpFoldResult> strides,
610                            ArrayRef<NamedAttribute> attrs) {
611   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
612   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
613   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
614                              ShapedType::kDynamicStrideOrOffset);
615   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
616                              ShapedType::kDynamicSize);
617   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
618                              ShapedType::kDynamicStrideOrOffset);
619   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
620   // Structuring implementation this way avoids duplication between builders.
621   if (!resultType) {
622     resultType =
623         ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
624                                         staticSizes, staticStrides)
625             .cast<RankedTensorType>();
626   }
627   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
628         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
629         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
630   result.addAttributes(attrs);
631 }
632 
633 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
634 /// result type.
635 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
636                            ArrayRef<OpFoldResult> offsets,
637                            ArrayRef<OpFoldResult> sizes,
638                            ArrayRef<OpFoldResult> strides,
639                            ArrayRef<NamedAttribute> attrs) {
640   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
641 }
642 
643 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
644 /// type passed is nullptr, it is inferred.
645 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
646                            RankedTensorType resultType, Value source,
647                            ValueRange offsets, ValueRange sizes,
648                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
649   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
650       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
651   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
652       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
653   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
654       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
656 }
657 
658 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
659 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
660                            ValueRange offsets, ValueRange sizes,
661                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
662   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
663 }
664 
665 enum SliceVerificationResult {
666   Success,
667   RankTooLarge,
668   SizeMismatch,
669   ElemTypeMismatch,
670 };
671 
/// Checks whether `originalType` can be rank-reduced to
/// `candidateReducedType`. This function is a slight variant of the
/// `isSubsequence` algorithm, where every non-matching dimension must be a
/// static unit dimension.
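///
/// For example (illustrative):
///   tensor<1x4x5xf32> -> tensor<4x5xf32>   : Success (unit dim dropped)
///   tensor<1x4x5xf32> -> tensor<4x6xf32>   : SizeMismatch
///   tensor<4x5xf32>   -> tensor<1x4x5xf32> : RankTooLarge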
static SliceVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType) {
678   if (originalType == candidateReducedType)
679     return SliceVerificationResult::Success;
680   if (!originalType.isa<RankedTensorType>())
681     return SliceVerificationResult::Success;
  if (!candidateReducedType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;
685 
686   ShapedType originalShapedType = originalType.cast<ShapedType>();
687   ShapedType candidateReducedShapedType =
688       candidateReducedType.cast<ShapedType>();
689 
690   // Rank and size logic is valid for all ShapedTypes.
691   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
692   ArrayRef<int64_t> candidateReducedShape =
693       candidateReducedShapedType.getShape();
694   unsigned originalRank = originalShape.size(),
695            candidateReducedRank = candidateReducedShape.size();
696   if (candidateReducedRank > originalRank)
697     return SliceVerificationResult::RankTooLarge;
698 
699   auto optionalUnusedDimsMask =
700       computeRankReductionMask(originalShape, candidateReducedShape);
701 
  // If no mask could be computed, the candidate sizes cannot be matched
  // against the original shape.
  if (!optionalUnusedDimsMask.hasValue())
704     return SliceVerificationResult::SizeMismatch;
705 
706   if (originalShapedType.getElementType() !=
707       candidateReducedShapedType.getElementType())
708     return SliceVerificationResult::ElemTypeMismatch;
709 
  return SliceVerificationResult::Success;
715 }
716 
717 template <typename OpTy>
718 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
719                                           OpTy op, Type expectedType,
720                                           StringRef errMsg = "") {
  auto shapedType = expectedType.cast<ShapedType>();
722   switch (result) {
723   case SliceVerificationResult::Success:
724     return success();
725   case SliceVerificationResult::RankTooLarge:
726     return op.emitError("expected result rank to be smaller or equal to ")
727            << "the source rank. " << errMsg;
728   case SliceVerificationResult::SizeMismatch:
729     return op.emitError("expected result type to be ")
730            << expectedType
731            << " or a rank-reduced version. (mismatch of result sizes) "
732            << errMsg;
733   case SliceVerificationResult::ElemTypeMismatch:
734     return op.emitError("expected result element type to be ")
           << shapedType.getElementType() << errMsg;
736   }
737   llvm_unreachable("unexpected extract_slice op verification result");
738 }
739 
740 /// Verifier for ExtractSliceOp.
741 static LogicalResult verify(ExtractSliceOp op) {
742   // Verify result type against inferred type.
743   auto expectedType = ExtractSliceOp::inferResultType(
744       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
745       extractFromI64ArrayAttr(op.static_sizes()),
746       extractFromI64ArrayAttr(op.static_strides()));
747   auto result = isRankReducedType(expectedType, op.getType());
748   return produceSliceErrorMsg(result, op, expectedType);
749 }
750 
/// Infer the canonical type of the result of an extract_slice op. Returns the
/// rank-reduced type when its rank matches `resultRank`, and the
/// non-rank-reduced type otherwise.
754 static RankedTensorType
755 getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
756                             ArrayRef<OpFoldResult> mixedOffsets,
757                             ArrayRef<OpFoldResult> mixedSizes,
758                             ArrayRef<OpFoldResult> mixedStrides) {
759   auto resultType =
760       ExtractSliceOp::inferRankReducedResultType(
761           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
762           .cast<RankedTensorType>();
763   if (resultType.getRank() != resultRank) {
764     resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
765                                                  mixedSizes, mixedStrides)
766                      .cast<RankedTensorType>();
767   }
768   return resultType;
769 }
770 
771 namespace {
/// Pattern to rewrite an extract_slice op with tensor.cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
775 ///
776 /// Example:
777 /// ```
778 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] :
///          tensor<?x?xf32> to tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] :
///          tensor<16x16xf32> to tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
786 /// ```
787 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
788 public:
789   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
790 
791   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
792                                 PatternRewriter &rewriter) const override {
    // Bail on any constant offset/size/stride operand; the constant-argument
    // folder (OpWithOffsetSizesAndStridesConstantArgumentFolder) should kick
    // in first.
794     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
795           return matchPattern(operand, matchConstantIndex());
796         }))
797       return failure();
798 
799     auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
800     if (!castOp)
801       return failure();
802 
803     if (!canFoldIntoConsumerOp(castOp))
804       return failure();
805 
    // Deduce the type of the result to use for the canonicalized operation.
807     RankedTensorType resultType = getCanonicalSliceResultType(
808         sliceOp.getType().getRank(), sliceOp.getSourceType(),
809         sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
810         sliceOp.getMixedStrides());
811     Value newSlice = rewriter.create<ExtractSliceOp>(
812         sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
813         sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
814         sliceOp.static_sizes(), sliceOp.static_strides());
815     rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
816                                                 newSlice);
817     return success();
818   }
819 };
820 } // namespace
821 
822 /// Return the canonical type of the result of an extract_slice op.
823 struct SliceReturnTypeCanonicalizer {
824   RankedTensorType operator()(ExtractSliceOp op,
825                               ArrayRef<OpFoldResult> mixedOffsets,
826                               ArrayRef<OpFoldResult> mixedSizes,
827                               ArrayRef<OpFoldResult> mixedStrides) {
828     return getCanonicalSliceResultType(op.getType().getRank(),
829                                        op.getSourceType(), mixedOffsets,
830                                        mixedSizes, mixedStrides);
831   }
832 };
833 
834 /// A canonicalizer wrapper to replace ExtractSliceOps.
835 struct SliceCanonicalizer {
836   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
837                   ExtractSliceOp newOp) {
838     Value replacement = newOp.getResult();
839     if (replacement.getType() != op.getType())
840       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
841                                                     replacement);
842     rewriter.replaceOp(op, replacement);
843   }
844 };
845 
846 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
847                                                  MLIRContext *context) {
848   results.add<
849       OpWithOffsetSizesAndStridesConstantArgumentFolder<
850           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
851       ExtractSliceOpCastFolder>(context);
852 }
853 
854 //
855 static LogicalResult
856 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
857                                            ShapedType shapedType) {
859   for (OpFoldResult ofr : op.getMixedOffsets())
860     if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
861       return failure();
  // Rank-reducing no-ops only need to inspect the leading dimensions, so
  // llvm::zip (which stops at the end of the shorter range) is appropriate.
864   auto shape = shapedType.getShape();
865   for (auto it : llvm::zip(op.getMixedSizes(), shape))
866     if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
867       return failure();
868   for (OpFoldResult ofr : op.getMixedStrides())
869     if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
870       return failure();
871   return success();
872 }
873 
874 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
875   if (getSourceType() == getType() &&
876       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
877     return this->source();
878   return OpFoldResult();
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // InsertSliceOp
883 //===----------------------------------------------------------------------===//
884 
/// Build an InsertSliceOp with mixed static and dynamic entries.
886 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
887                           Value dest, ArrayRef<OpFoldResult> offsets,
888                           ArrayRef<OpFoldResult> sizes,
889                           ArrayRef<OpFoldResult> strides,
890                           ArrayRef<NamedAttribute> attrs) {
891   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
892   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
893   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
894                              ShapedType::kDynamicStrideOrOffset);
895   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
896                              ShapedType::kDynamicSize);
897   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
898                              ShapedType::kDynamicStrideOrOffset);
899   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
900         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
901         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
902   result.addAttributes(attrs);
903 }
904 
/// Build an InsertSliceOp with dynamic entries.
906 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
907                           Value dest, ValueRange offsets, ValueRange sizes,
908                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
909   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
910       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
911   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
912       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
913   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
914       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues,
        attrs);
916 }
917 
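/// Folds a no-op insert_slice: when source and destination have the same
/// static type and the slice covers the whole destination with zero offsets
/// and unit strides, the result is just the source. For example
/// (illustrative):
///
/// ```mlir
///   %0 = tensor.insert_slice %src into %dst[0, 0][4, 8][1, 1]
///       : tensor<4x8xf32> into tensor<4x8xf32>
/// ```
///
/// folds to %src.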
918 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
919   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
920       getSourceType() == getType() &&
921       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
922     return this->source();
923   return OpFoldResult();
924 }
925 
926 namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
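///
/// For example (illustrative, assuming %c0 is a constant index of value 0):
///
/// ```mlir
///   %0 = tensor.insert_slice %src into %dst[%c0][4][1]
///       : tensor<4xf32> into tensor<8xf32>
/// ```
///
/// is rewritten so the offset is carried as a static attribute:
///
/// ```mlir
///   %0 = tensor.insert_slice %src into %dst[0][4][1]
///       : tensor<4xf32> into tensor<8xf32>
/// ```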
928 class InsertSliceOpConstantArgumentFolder final
929     : public OpRewritePattern<InsertSliceOp> {
930 public:
931   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
932 
933   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
934                                 PatternRewriter &rewriter) const override {
935     // No constant operand, just return.
936     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
937           return matchPattern(operand, matchConstantIndex());
938         }))
939       return failure();
940 
    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing ones.
944     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
945     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
946     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
947     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
948     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
949     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
950 
951     // Create the new op in canonical form.
952     rewriter.replaceOpWithNewOp<InsertSliceOp>(
953         insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
954         mixedOffsets, mixedSizes, mixedStrides);
955     return success();
956   }
957 };
958 
/// Fold tensor.cast ops with insert_slice operations.
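///
/// For example (illustrative):
///
/// ```mlir
///   %0 = tensor.cast %a : tensor<4xf32> to tensor<?xf32>
///   %1 = tensor.insert_slice %0 into %d[0][4][1]
///       : tensor<?xf32> into tensor<8xf32>
/// ```
///
/// is rewritten to insert %a directly:
///
/// ```mlir
///   %1 = tensor.insert_slice %a into %d[0][4][1]
///       : tensor<4xf32> into tensor<8xf32>
/// ```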
960 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
961   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
962 
963   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
964                                 PatternRewriter &rewriter) const override {
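    // Bail on constant offset/size/stride operands; let the constant-argument
    // folder canonicalize the op first.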
965     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
966           return matchPattern(operand, matchConstantIndex());
967         }))
968       return failure();
969 
970     auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
971       auto castOp = v.getDefiningOp<tensor::CastOp>();
972       if (!castOp || !canFoldIntoConsumerOp(castOp))
973         return llvm::None;
974       return castOp.source();
975     };
976     Optional<Value> sourceCastSource =
977         getSourceOfCastOp(insertSliceOp.source());
978     Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
979     if (!sourceCastSource && !destCastSource)
980       return failure();
981 
982     Value replacement = rewriter.create<InsertSliceOp>(
983         insertSliceOp.getLoc(),
984         (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
985         (destCastSource ? *destCastSource : insertSliceOp.dest()),
986         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
987         insertSliceOp.getMixedStrides());
988 
989     if (replacement.getType() != insertSliceOp.getType()) {
990       replacement = rewriter.create<tensor::CastOp>(
991           insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
992     }
993     rewriter.replaceOp(insertSliceOp, replacement);
994     return success();
995   }
996 };
997 } // namespace
998 
999 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1000                                                 MLIRContext *context) {
1001   results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
1002       context);
1003 }
1004 
1005 //===----------------------------------------------------------------------===//
1006 // TableGen'd op method definitions
1007 //===----------------------------------------------------------------------===//
1008 
1009 #define GET_OP_CLASSES
1010 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
1011