1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
14 #include "mlir/Dialect/Utils/StaticValueUtils.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinAttributeInterfaces.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallBitVector.h"
22 
23 using namespace mlir;
24 using namespace mlir::tensor;
25 
26 /// Materialize a single constant operation from a given attribute value with
27 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)28 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
29                                               Attribute value, Type type,
30                                               Location loc) {
31   if (arith::ConstantOp::isBuildableWith(value, type))
32     return builder.create<arith::ConstantOp>(loc, value, type);
33   if (complex::ConstantOp::isBuildableWith(value, type))
34     return builder.create<complex::ConstantOp>(loc, type,
35                                                value.cast<ArrayAttr>());
36   return nullptr;
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // CastOp
41 //===----------------------------------------------------------------------===//
42 
43 /// Returns true if `target` is a ranked tensor type that preserves static
44 /// information available in the `source` ranked tensor type.
preservesStaticInformation(Type source,Type target)45 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
46   auto sourceType = source.dyn_cast<RankedTensorType>();
47   auto targetType = target.dyn_cast<RankedTensorType>();
48 
49   // Requires RankedTensorType.
50   if (!sourceType || !targetType)
51     return false;
52 
53   // Requires same elemental type.
54   if (sourceType.getElementType() != targetType.getElementType())
55     return false;
56 
57   // Requires same rank.
58   if (sourceType.getRank() != targetType.getRank())
59     return false;
60 
61   // If cast is towards more static sizes along any dimension, don't fold.
62   for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
63     if (!ShapedType::isDynamic(std::get<0>(t)) &&
64         ShapedType::isDynamic(std::get<1>(t)))
65       return false;
66   }
67 
68   return true;
69 }
70 
71 /// Determines whether tensor::CastOp casts to a more dynamic version of the
72 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
73 /// implement canonicalization patterns for ops in different dialects that may
74 /// consume the results of tensor.cast operations. Such foldable tensor.cast
75 /// operations are typically inserted as `slice` ops and are canonicalized,
76 /// to preserve the type compatibility of their uses.
77 ///
78 /// Returns true when all conditions are met:
79 /// 1. source and result are ranked tensors with same element type and rank.
80 /// 2. the tensor type has more static information than the result
81 ///
82 /// Example:
83 /// ```mlir
84 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
85 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
86 /// ```
87 ///
88 /// folds into:
89 ///
90 /// ```mlir
91 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
92 /// ```
canFoldIntoConsumerOp(CastOp castOp)93 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
94   if (!castOp)
95     return false;
96 
97   // Can fold if the source of cast has at least as much static information as
98   // its results.
99   return preservesStaticInformation(castOp.getType(),
100                                     castOp.getSource().getType());
101 }
102 
103 /// Determines whether the tensor::CastOp casts to a more static version of the
104 /// source tensor. This is useful to fold into a producing op and implement
105 /// canonicaliation patterns with the `tensor.cast` op as the root, but producer
106 /// being from different dialects. Returns true when all conditions are met:
107 /// 1. source and result and ranked tensors with same element type and rank.
108 /// 2. the result type has more static information than the source.
109 ///
110 /// Example:
111 /// ```mlir
112 ///   %1 = producer ... : tensor<?x?xf32>
113 ///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
114 /// ```
115 ///
116 /// can be canonicalized to :
117 ///
118 /// ```mlir
119 ///   %2 = producer ... : tensor<8x16xf32>
120 /// ```
121 /// Not all ops might be canonicalizable this way, but for those that can be,
122 /// this method provides a check that it is worth doing the canonicalization.
canFoldIntoProducerOp(CastOp castOp)123 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
124   if (!castOp)
125     return false;
126   return preservesStaticInformation(castOp.getSource().getType(),
127                                     castOp.getType());
128 }
129 
130 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
131 /// that can be folded.
foldTensorCast(Operation * op)132 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
133   bool folded = false;
134   for (OpOperand &operand : op->getOpOperands()) {
135     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
136     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
137       operand.set(castOp.getOperand());
138       folded = true;
139     }
140   }
141   return success(folded);
142 }
143 
areCastCompatible(TypeRange inputs,TypeRange outputs)144 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
145   if (inputs.size() != 1 || outputs.size() != 1)
146     return false;
147   Type a = inputs.front(), b = outputs.front();
148   auto aT = a.dyn_cast<TensorType>();
149   auto bT = b.dyn_cast<TensorType>();
150   if (!aT || !bT)
151     return false;
152 
153   if (aT.getElementType() != bT.getElementType())
154     return false;
155 
156   return succeeded(verifyCompatibleShape(aT, bT));
157 }
158 
159 /// Compute a TensorType that has the joined shape knowledge of the two
160 /// given TensorTypes. The element types need to match.
joinShapes(TensorType one,TensorType two)161 static TensorType joinShapes(TensorType one, TensorType two) {
162   assert(one.getElementType() == two.getElementType());
163 
164   if (!one.hasRank())
165     return two;
166   if (!two.hasRank())
167     return one;
168 
169   int64_t rank = one.getRank();
170   if (rank != two.getRank())
171     return {};
172 
173   SmallVector<int64_t, 4> join;
174   join.reserve(rank);
175   for (int64_t i = 0; i < rank; ++i) {
176     if (one.isDynamicDim(i)) {
177       join.push_back(two.getDimSize(i));
178       continue;
179     }
180     if (two.isDynamicDim(i)) {
181       join.push_back(one.getDimSize(i));
182       continue;
183     }
184     if (one.getDimSize(i) != two.getDimSize(i))
185       return {};
186     join.push_back(one.getDimSize(i));
187   }
188   return RankedTensorType::get(join, one.getElementType());
189 }
190 
191 namespace {
192 
193 /// Replaces chains of two tensor.cast operations by a single tensor.cast
194 /// operation if doing so does not remove runtime constraints.
195 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
196   using OpRewritePattern<CastOp>::OpRewritePattern;
197 
matchAndRewrite__anon3fb9f79f0111::ChainedTensorCast198   LogicalResult matchAndRewrite(CastOp tensorCast,
199                                 PatternRewriter &rewriter) const final {
200     auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
201 
202     if (!tensorCastOperand)
203       return failure();
204 
205     auto sourceType =
206         tensorCastOperand.getOperand().getType().cast<TensorType>();
207     auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
208     auto resultType = tensorCast.getType().cast<TensorType>();
209 
210     // We can remove the intermediate cast if joining all three produces the
211     // same result as just joining the source and result shapes.
212     auto firstJoin =
213         joinShapes(joinShapes(sourceType, intermediateType), resultType);
214 
215     // The join might not exist if the cast sequence would fail at runtime.
216     if (!firstJoin)
217       return failure();
218 
219     // The newJoin always exists if the above join exists, it might just contain
220     // less information. If so, we cannot drop the intermediate cast, as doing
221     // so would remove runtime checks.
222     auto newJoin = joinShapes(sourceType, resultType);
223     if (firstJoin != newJoin)
224       return failure();
225 
226     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
227                                         tensorCastOperand.getOperand());
228     return success();
229   }
230 };
231 
232 /// Fold tensor.cast into tesor.extract_slice producer.
233 /// Example:
234 /// ```
235 ///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
236 ///    tensor<128x512xf32> to tensor<?x512xf32>
237 ///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
238 /// ```
239 /// ->
240 /// ```
241 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
242 ///   tensor<128x512xf32> to tensor<16x512xf32>
243 /// ```
244 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
245   using OpRewritePattern<CastOp>::OpRewritePattern;
246 
matchAndRewrite__anon3fb9f79f0111::TensorCastExtractSlice247   LogicalResult matchAndRewrite(CastOp tensorCast,
248                                 PatternRewriter &rewriter) const final {
249     auto extractOperand =
250         tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
251 
252     if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
253         tensorCast.getType().getShape() == tensorCast.getSource()
254                                                .getType()
255                                                .cast<RankedTensorType>()
256                                                .getShape())
257       return failure();
258 
259     SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
260     auto dimMask = computeRankReductionMask(
261         extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
262         extractOperand.getType().getShape());
263     size_t dimIndex = 0;
264     for (size_t i = 0, e = sizes.size(); i < e; i++) {
265       if (dimMask && dimMask->count(i))
266         continue;
267       int64_t dim = tensorCast.getType().getShape()[dimIndex++];
268       if (ShapedType::isDynamic(dim))
269         continue;
270       sizes[i] = rewriter.getIndexAttr(dim);
271     }
272 
273     rewriter.replaceOpWithNewOp<ExtractSliceOp>(
274         tensorCast, tensorCast.getType().cast<RankedTensorType>(),
275         extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
276         extractOperand.getMixedStrides());
277     return success();
278   }
279 };
280 
281 } // namespace
282 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)283 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
284                                          MLIRContext *context) {
285   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // DimOp
290 //===----------------------------------------------------------------------===//
291 
build(OpBuilder & builder,OperationState & result,Value source,int64_t index)292 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
293                   int64_t index) {
294   auto loc = result.location;
295   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
296   build(builder, result, source, indexValue);
297 }
298 
getConstantIndex()299 Optional<int64_t> DimOp::getConstantIndex() {
300   if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
301     return constantOp.getValue().cast<IntegerAttr>().getInt();
302   return {};
303 }
304 
verify()305 LogicalResult DimOp::verify() {
306   // Assume unknown index to be in range.
307   Optional<int64_t> index = getConstantIndex();
308   if (!index)
309     return success();
310 
311   // Check that constant index is not knowingly out of range.
312   auto type = getSource().getType();
313   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
314     if (*index >= tensorType.getRank())
315       return emitOpError("index is out of range");
316   } else if (type.isa<UnrankedTensorType>()) {
317     // Assume index to be in range.
318   } else {
319     llvm_unreachable("expected operand with tensor type");
320   }
321   return success();
322 }
323 
fold(ArrayRef<Attribute> operands)324 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
325   // All forms of folding require a known index.
326   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
327   if (!index)
328     return {};
329 
330   // Folding for unranked types (UnrankedTensorType) is not supported.
331   auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
332   if (!tensorType)
333     return {};
334 
335   // Fold if the shape extent along the given index is known.
336   if (!tensorType.isDynamicDim(index.getInt())) {
337     Builder builder(getContext());
338     return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
339   }
340 
341   Operation *definingOp = getSource().getDefiningOp();
342 
343   // Fold dim to the operand of tensor.generate.
344   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
345     auto resultType =
346         fromElements.getResult().getType().cast<RankedTensorType>();
347     // The case where the type encodes the size of the dimension is handled
348     // above.
349     assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
350 
351     // Find the operand of the fromElements that corresponds to this index.
352     auto dynExtents = fromElements.getDynamicExtents().begin();
353     for (auto dim : resultType.getShape().take_front(index.getInt()))
354       if (ShapedType::isDynamic(dim))
355         dynExtents++;
356 
357     return Value{*dynExtents};
358   }
359 
360   // The size at the given index is now known to be a dynamic size.
361   unsigned unsignedIndex = index.getValue().getZExtValue();
362 
363   if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
364     // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
365     // `resolve-shaped-type-result-dims` pass.
366     if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
367         sliceOp.isDynamicSize(unsignedIndex)) {
368       return {sliceOp.getDynamicSize(unsignedIndex)};
369     }
370   }
371 
372   // dim(cast) -> dim
373   if (succeeded(foldTensorCast(*this)))
374     return getResult();
375 
376   return {};
377 }
378 
379 namespace {
380 /// Fold dim of a cast into the dim of the source of the tensor cast.
381 struct DimOfCastOp : public OpRewritePattern<DimOp> {
382   using OpRewritePattern<DimOp>::OpRewritePattern;
383 
matchAndRewrite__anon3fb9f79f0211::DimOfCastOp384   LogicalResult matchAndRewrite(DimOp dimOp,
385                                 PatternRewriter &rewriter) const override {
386     auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
387     if (!castOp)
388       return failure();
389     Value newSource = castOp.getOperand();
390     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
391     return success();
392   }
393 };
394 } // namespace
395 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)396 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
397                                         MLIRContext *context) {
398   results.add<DimOfCastOp>(context);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // ExtractOp
403 //===----------------------------------------------------------------------===//
404 
verify()405 LogicalResult ExtractOp::verify() {
406   // Verify the # indices match if we have a ranked type.
407   if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
408     if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
409       return emitOpError("incorrect number of indices for extract_element");
410 
411   return success();
412 }
413 
fold(ArrayRef<Attribute> operands)414 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
415   // If this is a splat elements attribute, simply return the value. All of the
416   // elements of a splat attribute are the same.
417   if (Attribute tensor = operands.front())
418     if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
419       return splatTensor.getSplatValue<Attribute>();
420 
421   // Collect the constant indices into the tensor.
422   SmallVector<uint64_t, 8> indices;
423   for (Attribute indice : llvm::drop_begin(operands, 1)) {
424     if (!indice || !indice.isa<IntegerAttr>())
425       return {};
426     indices.push_back(indice.cast<IntegerAttr>().getInt());
427   }
428 
429   // Fold extract(from_elements(...)).
430   if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
431     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
432     auto rank = tensorType.getRank();
433     assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
434            "rank mismatch");
435     int flatIndex = 0;
436     int stride = 1;
437     for (int i = rank - 1; i >= 0; --i) {
438       if (i < rank - 1)
439         stride *= tensorType.getDimSize(i);
440       flatIndex += indices[i] * stride;
441     }
442     // Prevent out of bounds accesses. This can happen in invalid code that will
443     // never execute.
444     if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
445         flatIndex < 0)
446       return {};
447     return fromElementsOp.getElements()[flatIndex];
448   }
449 
450   // If this is an elements attribute, query the value at the given indices.
451   if (Attribute tensor = operands.front()) {
452     auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
453     if (elementsAttr && elementsAttr.isValidIndex(indices))
454       return elementsAttr.getValues<Attribute>()[indices];
455   }
456 
457   return {};
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // FromElementsOp
462 //===----------------------------------------------------------------------===//
463 
build(OpBuilder & builder,OperationState & result,Type resultType,ValueRange elements)464 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
465                            Type resultType, ValueRange elements) {
466   result.addOperands(elements);
467   result.addTypes(resultType);
468 }
469 
build(OpBuilder & builder,OperationState & result,ValueRange elements)470 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
471                            ValueRange elements) {
472   assert(!elements.empty() && "expected at least one element");
473   Type resultType = RankedTensorType::get(
474       {static_cast<int64_t>(elements.size())}, elements.front().getType());
475   build(builder, result, resultType, elements);
476 }
477 
fold(ArrayRef<Attribute> operands)478 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
479   if (!llvm::is_contained(operands, nullptr))
480     return DenseElementsAttr::get(getType(), operands);
481   return {};
482 }
483 
484 namespace {
485 
486 // Pushes the index_casts that occur before extractions to after the extract.
487 // This minimizes type conversion in some cases and enables the extract
488 // canonicalizer. This changes:
489 //
490 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
491 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
492 //
493 // to the following:
494 //
495 // %extract = tensor.extract %tensor[%index] : tensor<1xindex>
496 // %cast = arith.index_cast %extract : i32 to index
497 //
498 // to just %element.
499 //
500 // Consider expanding this to a template and handle all tensor cast operations.
501 struct ExtractElementFromIndexCast
502     : public OpRewritePattern<tensor::ExtractOp> {
503   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
504 
matchAndRewrite__anon3fb9f79f0311::ExtractElementFromIndexCast505   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
506                                 PatternRewriter &rewriter) const final {
507     Location loc = extract.getLoc();
508     auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
509     if (!indexCast)
510       return failure();
511 
512     Type elementTy = getElementTypeOrSelf(indexCast.getIn());
513 
514     auto newExtract = rewriter.create<tensor::ExtractOp>(
515         loc, elementTy, indexCast.getIn(), extract.getIndices());
516 
517     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
518                                                     newExtract);
519 
520     return success();
521   }
522 };
523 
524 } // namespace
525 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)526 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
527                                                  MLIRContext *context) {
528   results.add<ExtractElementFromIndexCast>(context);
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // InsertOp
533 //===----------------------------------------------------------------------===//
534 
verify()535 LogicalResult InsertOp::verify() {
536   // Verify the # indices match if we have a ranked type.
537   if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
538     if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
539       return emitOpError("incorrect number of indices");
540   return success();
541 }
542 
fold(ArrayRef<Attribute> operands)543 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
544   Attribute scalar = operands[0];
545   Attribute dest = operands[1];
546   if (scalar && dest)
547     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
548       if (scalar == splatDest.getSplatValue<Attribute>())
549         return dest;
550   return {};
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // GenerateOp
555 //===----------------------------------------------------------------------===//
556 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)557 LogicalResult GenerateOp::reifyResultShapes(
558     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
559   reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
560   int idx = 0;
561   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
562     if (getType().isDynamicDim(dim)) {
563       reifiedReturnShapes[0][dim] = getOperand(idx++);
564     } else {
565       reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
566           getLoc(), getType().getDimSize(dim));
567     }
568   }
569   return success();
570 }
571 
verify()572 LogicalResult GenerateOp::verify() {
573   // Ensure that the tensor type has as many dynamic dimensions as are specified
574   // by the operands.
575   RankedTensorType resultTy = getType().cast<RankedTensorType>();
576   if (getNumOperands() != resultTy.getNumDynamicDims())
577     return emitError("must have as many index operands as dynamic extents "
578                      "in the result type");
579 
580   return success();
581 }
582 
verifyRegions()583 LogicalResult GenerateOp::verifyRegions() {
584   RankedTensorType resultTy = getType().cast<RankedTensorType>();
585   // Ensure that region arguments span the index space.
586   if (!llvm::all_of(getBody().getArgumentTypes(),
587                     [](Type ty) { return ty.isIndex(); }))
588     return emitError("all body arguments must be index");
589   if (getBody().getNumArguments() != resultTy.getRank())
590     return emitError("must have one body argument per input dimension");
591 
592   // Ensure that the region yields an element of the right type.
593   auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
594 
595   if (yieldOp.getValue().getType() != resultTy.getElementType())
596     return emitOpError(
597         "body must be terminated with a `yield` operation of the tensor "
598         "element type");
599 
600   return success();
601 }
602 
build(OpBuilder & b,OperationState & result,Type resultTy,ValueRange dynamicExtents,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilder)603 void GenerateOp::build(
604     OpBuilder &b, OperationState &result, Type resultTy,
605     ValueRange dynamicExtents,
606     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
607   build(b, result, resultTy, dynamicExtents);
608 
609   // Build and populate body.
610   OpBuilder::InsertionGuard guard(b);
611   Region *bodyRegion = result.regions.front().get();
612   auto rank = resultTy.cast<RankedTensorType>().getRank();
613   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
614   SmallVector<Location, 2> argumentLocs(rank, result.location);
615   Block *bodyBlock =
616       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
617   bodyBuilder(b, result.location, bodyBlock->getArguments());
618 }
619 
620 namespace {
621 
622 /// Canonicalizes tensor.generate operations with a constant
623 /// operand into the equivalent operation with the operand expressed in the
624 /// result type, instead. We also insert a type cast to make sure that the
625 /// resulting IR is still well-typed.
626 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
627   using OpRewritePattern<GenerateOp>::OpRewritePattern;
628 
matchAndRewrite__anon3fb9f79f0511::StaticTensorGenerate629   LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
630                                 PatternRewriter &rewriter) const final {
631     auto resultType =
632         tensorFromElements.getResult().getType().cast<RankedTensorType>();
633 
634     if (resultType.hasStaticShape())
635       return failure();
636 
637     SmallVector<Value, 4> newOperands;
638     SmallVector<int64_t, 4> newShape;
639     auto operandsIt = tensorFromElements.getDynamicExtents().begin();
640 
641     for (int64_t dim : resultType.getShape()) {
642       if (!ShapedType::isDynamic(dim)) {
643         newShape.push_back(dim);
644         continue;
645       }
646       APInt index;
647       if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
648         newShape.push_back(ShapedType::kDynamicSize);
649         newOperands.push_back(*operandsIt++);
650         continue;
651       }
652       newShape.push_back(index.getSExtValue());
653       operandsIt++;
654     }
655 
656     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
657       return failure();
658 
659     auto loc = tensorFromElements.getLoc();
660     auto newOp = rewriter.create<GenerateOp>(
661         loc, RankedTensorType::get(newShape, resultType.getElementType()),
662         newOperands);
663     rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
664                                 newOp.getBody().begin());
665     rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
666                                                 newOp);
667     return success();
668   }
669 };
670 
671 /// Canonicalizes the pattern of the form
672 ///
673 /// %tensor = tensor.generate %x {
674 ///   ^bb0(%arg0: index):
675 ///   <computation>
676 ///   yield %1 : index
677 /// } : tensor<?xindex>
678 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
679 ///
680 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
681 /// tensor.generate operation has no side-effects.
682 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
683   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
684 
matchAndRewrite__anon3fb9f79f0511::ExtractFromTensorGenerate685   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
686                                 PatternRewriter &rewriter) const final {
687     auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
688     if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
689       return failure();
690 
691     BlockAndValueMapping mapping;
692     Block *body = &tensorFromElements.getBody().front();
693     mapping.map(body->getArguments(), extract.getIndices());
694     for (auto &op : body->without_terminator())
695       rewriter.clone(op, mapping);
696 
697     auto yield = cast<YieldOp>(body->getTerminator());
698 
699     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
700     return success();
701   }
702 };
703 
704 /// Canonicalizes the pattern of the form
705 ///
706 /// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
707 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
708 ///
709 /// to
710 ///
711 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
712 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
713   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
714 
matchAndRewrite__anon3fb9f79f0511::ExtractFromTensorCast715   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
716                                 PatternRewriter &rewriter) const final {
717     auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
718     if (!tensorCast)
719       return failure();
720 
721     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
722         extract, tensorCast.getSource(), extract.getIndices());
723     return success();
724   }
725 };
726 
727 } // namespace
728 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)729 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
730                                              MLIRContext *context) {
731   // TODO: Move extract patterns to tensor::ExtractOp.
732   results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
733               StaticTensorGenerate>(context);
734 }
735 
736 //===----------------------------------------------------------------------===//
737 // RankOp
738 //===----------------------------------------------------------------------===//
739 
fold(ArrayRef<Attribute> operands)740 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
741   // Constant fold rank when the rank of the operand is known.
742   auto type = getOperand().getType();
743   auto shapedType = type.dyn_cast<ShapedType>();
744   if (shapedType && shapedType.hasRank())
745     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
746   return IntegerAttr();
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // ReshapeOp
751 //===----------------------------------------------------------------------===//
752 
getNumElements(ShapedType type)753 static int64_t getNumElements(ShapedType type) {
754   int64_t numElements = 1;
755   for (auto dim : type.getShape())
756     numElements *= dim;
757   return numElements;
758 }
759 
verify()760 LogicalResult ReshapeOp::verify() {
761   TensorType operandType = getSource().getType().cast<TensorType>();
762   TensorType resultType = getResult().getType().cast<TensorType>();
763 
764   if (operandType.getElementType() != resultType.getElementType())
765     return emitOpError("element types of source and destination tensor "
766                        "types should be the same");
767 
768   int64_t shapeSize =
769       getShape().getType().cast<RankedTensorType>().getDimSize(0);
770   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
771   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
772 
773   if (resultRankedType) {
774     if (operandRankedType && resultRankedType.hasStaticShape() &&
775         operandRankedType.hasStaticShape()) {
776       if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
777         return emitOpError("source and destination tensor should have the "
778                            "same number of elements");
779     }
780     if (ShapedType::isDynamic(shapeSize))
781       return emitOpError("cannot use shape operand with dynamic length to "
782                          "reshape to statically-ranked tensor type");
783     if (shapeSize != resultRankedType.getRank())
784       return emitOpError(
785           "length of shape operand differs from the result's tensor rank");
786   }
787   return success();
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // Reassociative reshape ops
792 //===----------------------------------------------------------------------===//
793 
getReassociationMaps()794 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
795   return getSymbolLessAffineMaps(getReassociationExprs());
796 }
getReassociationExprs()797 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
798   return convertReassociationIndicesToExprs(getContext(),
799                                             getReassociationIndices());
800 }
801 
getReassociationMaps()802 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
803   return getSymbolLessAffineMaps(getReassociationExprs());
804 }
getReassociationExprs()805 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
806   return convertReassociationIndicesToExprs(getContext(),
807                                             getReassociationIndices());
808 }
809 
810 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
811 static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,ArrayRef<AffineMap> reassociation)812 computeTensorReshapeCollapsedType(RankedTensorType type,
813                                   ArrayRef<AffineMap> reassociation) {
814   auto shape = type.getShape();
815   SmallVector<int64_t, 4> newShape;
816   newShape.reserve(reassociation.size());
817 
818   // Use the fact that reassociation is valid to simplify the logic: only use
819   // each map's rank.
820   assert(isReassociationValid(reassociation) && "invalid reassociation");
821   unsigned currentDim = 0;
822   for (AffineMap m : reassociation) {
823     unsigned dim = m.getNumResults();
824     auto band = shape.slice(currentDim, dim);
825     int64_t size = 1;
826     if (llvm::is_contained(band, ShapedType::kDynamicSize))
827       size = ShapedType::kDynamicSize;
828     else
829       for (unsigned d = 0; d < dim; ++d)
830         size *= shape[currentDim + d];
831     newShape.push_back(size);
832     currentDim += dim;
833   }
834 
835   return RankedTensorType::get(newShape, type.getElementType());
836 }
837 
build(OpBuilder & b,OperationState & result,Value src,ArrayRef<ReassociationIndices> reassociation,ArrayRef<NamedAttribute> attrs)838 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
839                             ArrayRef<ReassociationIndices> reassociation,
840                             ArrayRef<NamedAttribute> attrs) {
841   auto resultType = computeTensorReshapeCollapsedType(
842       src.getType().cast<RankedTensorType>(),
843       getSymbolLessAffineMaps(
844           convertReassociationIndicesToExprs(b.getContext(), reassociation)));
845   build(b, result, resultType, src, attrs);
846   result.addAttribute(getReassociationAttrStrName(),
847                       getReassociationIndicesAttribute(b, reassociation));
848 }
849 
850 // Checks if types are the same, but ignoring encoding on ranked tensors.
isSameTypesWithoutEncoding(Type tp1,Type tp2)851 static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
852   if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
853     if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
854       return rtp1.getShape() == rtp2.getShape() &&
855              rtp1.getElementType() == rtp2.getElementType();
856     return false;
857   }
858   // Default implementation.
859   return tp1 == tp2;
860 }
861 
862 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
863                                         TensorReshapeOp, ExpandShapeOp>::value>
verifyTensorReshapeOp(TensorReshapeOp op,RankedTensorType expandedType,RankedTensorType collapsedType)864 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
865                                            RankedTensorType expandedType,
866                                            RankedTensorType collapsedType) {
867   if (failed(
868           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
869     return failure();
870 
871   auto maps = op.getReassociationMaps();
872   RankedTensorType expectedType =
873       computeTensorReshapeCollapsedType(expandedType, maps);
874   if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
875     return op.emitOpError("expected collapsed type to be ")
876            << expectedType << ", but got " << collapsedType;
877   return success();
878 }
879 
verify()880 LogicalResult ExpandShapeOp::verify() {
881   return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
882 }
883 
verify()884 LogicalResult CollapseShapeOp::verify() {
885   return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
886 }
887 
888 namespace {
889 /// Reshape of a splat constant can be replaced with a constant of the result
890 /// type.
891 template <typename TensorReshapeOp>
892 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
893   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon3fb9f79f0611::FoldReshapeWithConstant894   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
895                                 PatternRewriter &rewriter) const override {
896     DenseElementsAttr attr;
897     if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
898       return failure();
899     if (!attr || !attr.isSplat())
900       return failure();
901     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
902         reshapeOp.getResultType(), attr.getRawData());
903     rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
904     return success();
905   }
906 };
907 
908 /// Reshape of a FromElements can be replaced with a FromElements of the result
909 /// type
910 template <typename TensorReshapeOp>
911 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
912   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon3fb9f79f0611::FoldReshapeWithFromElements913   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
914                                 PatternRewriter &rewriter) const override {
915     auto fromElements =
916         reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
917     if (!fromElements)
918       return failure();
919 
920     auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
921 
922     if (!shapedTy.hasStaticShape())
923       return failure();
924 
925     rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
926                                                 fromElements.getElements());
927     return success();
928   }
929 };
930 
931 } // namespace
932 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)933 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
934                                                 MLIRContext *context) {
935   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
936               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
937               FoldReshapeWithConstant<ExpandShapeOp>,
938               FoldReshapeWithFromElements<ExpandShapeOp>>(context);
939 }
940 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)941 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
942                                                   MLIRContext *context) {
943   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
944               ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
945               FoldReshapeWithConstant<CollapseShapeOp>,
946               FoldReshapeWithFromElements<CollapseShapeOp>>(context);
947 }
948 
fold(ArrayRef<Attribute> operands)949 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
950   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
951 }
fold(ArrayRef<Attribute> operands)952 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
953   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
954 }
955 
956 //===----------------------------------------------------------------------===//
957 // ExtractSliceOp
958 //===----------------------------------------------------------------------===//
959 
960 /// An extract_slice result type can be inferred, when it is not
961 /// rank-reduced, from the source type and the static representation of
962 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
inferResultType(ShapedType sourceShapedTensorType,ArrayRef<int64_t> staticOffsets,ArrayRef<int64_t> staticSizes,ArrayRef<int64_t> staticStrides)963 RankedTensorType ExtractSliceOp::inferResultType(
964     ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
965     ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
966   // An extract_slice op may specify only a leading subset of offset/sizes/
967   // strides in which case we complete with offset=0, sizes from memref type and
968   // strides=1.
969   assert(static_cast<int64_t>(staticSizes.size()) ==
970              sourceShapedTensorType.getRank() &&
971          "unexpected staticSizes not equal to rank of source");
972   return RankedTensorType::get(staticSizes,
973                                sourceShapedTensorType.getElementType());
974 }
975 
inferResultType(ShapedType sourceShapedTensorType,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides)976 RankedTensorType ExtractSliceOp::inferResultType(
977     ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
978     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
979   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
980   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
981   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
982                              ShapedType::kDynamicStrideOrOffset);
983   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
984                              ShapedType::kDynamicSize);
985   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
986                              ShapedType::kDynamicStrideOrOffset);
987   return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
988                                          staticSizes, staticStrides);
989 }
990 
991 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
992 /// number of sizes), drop as many size 1 as needed to produce an inferred type
993 /// with the desired rank.
994 ///
995 /// Note that there may be multiple ways to compute this rank-reduced type:
996 ///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
997 ///
998 /// To disambiguate, this function always drops the first 1 sizes occurrences.
inferCanonicalRankReducedResultType(unsigned desiredResultRank,RankedTensorType sourceRankedTensorType,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)999 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1000     unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1001     ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1002     ArrayRef<int64_t> strides) {
1003   // Type inferred in the absence of rank-reducing behavior.
1004   auto inferredType =
1005       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1006           .cast<RankedTensorType>();
1007   int rankDiff = inferredType.getRank() - desiredResultRank;
1008   if (rankDiff > 0) {
1009     auto shape = inferredType.getShape();
1010     llvm::SmallBitVector dimsToProject =
1011         getPositionsOfShapeOne(rankDiff, shape);
1012     SmallVector<int64_t> projectedShape;
1013     // Best effort rank-reducing: drop 1s in order.
1014     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1015       if (!dimsToProject.test(pos))
1016         projectedShape.push_back(shape[pos]);
1017     inferredType =
1018         RankedTensorType::get(projectedShape, inferredType.getElementType());
1019   }
1020   return inferredType;
1021 }
1022 
inferCanonicalRankReducedResultType(unsigned desiredResultRank,RankedTensorType sourceRankedTensorType,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides)1023 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1024     unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1025     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1026     ArrayRef<OpFoldResult> strides) {
1027   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1028   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1029   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1030                              ShapedType::kDynamicStrideOrOffset);
1031   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1032                              ShapedType::kDynamicSize);
1033   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1034                              ShapedType::kDynamicStrideOrOffset);
1035   return ExtractSliceOp::inferCanonicalRankReducedResultType(
1036       desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1037       staticStrides);
1038 }
1039 
1040 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1041 /// result type. If the type passed is nullptr, it is inferred.
build(OpBuilder & b,OperationState & result,RankedTensorType resultType,Value source,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1042 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1043                            RankedTensorType resultType, Value source,
1044                            ArrayRef<OpFoldResult> offsets,
1045                            ArrayRef<OpFoldResult> sizes,
1046                            ArrayRef<OpFoldResult> strides,
1047                            ArrayRef<NamedAttribute> attrs) {
1048   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1049   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1050   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1051                              ShapedType::kDynamicStrideOrOffset);
1052   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1053                              ShapedType::kDynamicSize);
1054   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1055                              ShapedType::kDynamicStrideOrOffset);
1056   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
1057   // Structuring implementation this way avoids duplication between builders.
1058   if (!resultType) {
1059     resultType =
1060         ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
1061                                         staticSizes, staticStrides)
1062             .cast<RankedTensorType>();
1063   }
1064   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1065         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1066         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1067   result.addAttributes(attrs);
1068 }
1069 
1070 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1071 /// result type.
build(OpBuilder & b,OperationState & result,Value source,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1072 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1073                            ArrayRef<OpFoldResult> offsets,
1074                            ArrayRef<OpFoldResult> sizes,
1075                            ArrayRef<OpFoldResult> strides,
1076                            ArrayRef<NamedAttribute> attrs) {
1077   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1078 }
1079 
1080 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
1081 /// type passed is nullptr, it is inferred.
build(OpBuilder & b,OperationState & result,RankedTensorType resultType,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1082 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1083                            RankedTensorType resultType, Value source,
1084                            ValueRange offsets, ValueRange sizes,
1085                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1086   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1087       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1088   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1089       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1090   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1091       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1092   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1093 }
1094 
1095 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
build(OpBuilder & b,OperationState & result,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1096 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1097                            ValueRange offsets, ValueRange sizes,
1098                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1099   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1100 }
1101 
1102 template <typename OpTy>
produceSliceErrorMsg(SliceVerificationResult result,OpTy op,Type expectedType)1103 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
1104                                           OpTy op, Type expectedType) {
1105   auto memrefType = expectedType.cast<ShapedType>();
1106   switch (result) {
1107   case SliceVerificationResult::Success:
1108     return success();
1109   case SliceVerificationResult::RankTooLarge:
1110     return op.emitError("expected rank to be smaller or equal to ")
1111            << "the other rank. ";
1112   case SliceVerificationResult::SizeMismatch:
1113     return op.emitError("expected type to be ")
1114            << expectedType << " or a rank-reduced version. (size mismatch) ";
1115   case SliceVerificationResult::ElemTypeMismatch:
1116     return op.emitError("expected element type to be ")
1117            << memrefType.getElementType();
1118   default:
1119     llvm_unreachable("unexpected extract_slice op verification result");
1120   }
1121 }
1122 
1123 /// Verifier for ExtractSliceOp.
verify()1124 LogicalResult ExtractSliceOp::verify() {
1125   // Verify result type against inferred type.
1126   RankedTensorType expectedType = ExtractSliceOp::inferResultType(
1127       getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1128   SliceVerificationResult result = isRankReducedType(expectedType, getType());
1129   return produceSliceErrorMsg(result, *this, expectedType);
1130 }
1131 
getDroppedDims()1132 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
1133   ArrayRef<int64_t> resultShape = getType().getShape();
1134   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1135   llvm::SmallBitVector droppedDims(mixedSizes.size());
1136   unsigned shapePos = 0;
1137   for (const auto &size : enumerate(mixedSizes)) {
1138     Optional<int64_t> sizeVal = getConstantIntValue(size.value());
1139     // If the size is not 1, or if the current matched dimension of the result
1140     // is the same static shape as the size value (which is 1), then the
1141     // dimension is preserved.
1142     if (!sizeVal || *sizeVal != 1 ||
1143         (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
1144       shapePos++;
1145       continue;
1146     }
1147     droppedDims.set(size.index());
1148   }
1149   return droppedDims;
1150 }
1151 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)1152 LogicalResult ExtractSliceOp::reifyResultShapes(
1153     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1154   reifiedReturnShapes.resize(1);
1155   reifiedReturnShapes[0].reserve(getType().getRank());
1156   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1157   llvm::SmallBitVector droppedDims = getDroppedDims();
1158   Location loc = getLoc();
1159   for (const auto &size : enumerate(mixedSizes)) {
1160     if (droppedDims.test(size.index()))
1161       continue;
1162     if (auto attr = size.value().dyn_cast<Attribute>()) {
1163       reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
1164           loc, attr.cast<IntegerAttr>().getInt()));
1165       continue;
1166     }
1167     reifiedReturnShapes[0].push_back(size.value().get<Value>());
1168   }
1169   return success();
1170 }
1171 
1172 namespace {
1173 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
1174 /// This essentially pushes memref_cast past its consuming slice when
1175 /// `canFoldIntoConsumerOp` is true.
1176 ///
1177 /// Example:
1178 /// ```
1179 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1180 ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1181 ///   tensor<3x4xf32>
1182 /// ```
1183 /// is rewritten into:
1184 /// ```
1185 ///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
1186 ///   tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
1187 /// ```
1188 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1189 public:
1190   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1191 
matchAndRewrite(ExtractSliceOp sliceOp,PatternRewriter & rewriter) const1192   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
1193                                 PatternRewriter &rewriter) const override {
1194     // Any constant operand, just return to let the constant folder kick in.
1195     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1196           return matchPattern(operand, matchConstantIndex());
1197         }))
1198       return failure();
1199 
1200     auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
1201     if (!castOp)
1202       return failure();
1203 
1204     if (!canFoldIntoConsumerOp(castOp))
1205       return failure();
1206 
1207     /// Deduce the type of the result to use for the canonicalized operation.
1208     RankedTensorType resultType =
1209         ExtractSliceOp::inferCanonicalRankReducedResultType(
1210             sliceOp.getType().getRank(), sliceOp.getSourceType(),
1211             sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
1212             sliceOp.getMixedStrides());
1213     Value newSlice = rewriter.create<ExtractSliceOp>(
1214         sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
1215         sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
1216         sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1217     rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
1218                                                 newSlice);
1219     return success();
1220   }
1221 };
1222 
1223 /// Slice elements from `values` into `outValues`. `counts` holds, for each
1224 /// dimension, the number of elements a unit step spans in the original values.
1225 /// The output values can be used to construct a DenseElementsAttr.
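/// As an illustration, for a 2x3 source (counts = {3, 1}) with offsets = {0, 1},
/// sizes = {2, 2}, and strides = {1, 1}, the helper visits the linearized
/// positions {1, 2, 4, 5}, i.e. the last two columns of each row.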
1226 template <typename IterTy, typename ElemTy>
sliceElements(IterTy values,ArrayRef<int64_t> counts,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,llvm::SmallVectorImpl<ElemTy> * outValues)1227 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
1228                           ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1229                           ArrayRef<int64_t> strides,
1230                           llvm::SmallVectorImpl<ElemTy> *outValues) {
1231   assert(offsets.size() == sizes.size());
1232   assert(offsets.size() == strides.size());
1233   if (offsets.empty())
1234     return;
1235 
1236   int64_t offset = offsets.front();
1237   int64_t size = sizes.front();
1238   int64_t stride = strides.front();
1239   if (offsets.size() == 1) {
1240     for (int64_t i = 0; i < size; ++i, offset += stride)
1241       outValues->push_back(*(values + offset));
1242 
1243     return;
1244   }
1245 
1246   for (int64_t i = 0; i < size; ++i, offset += stride) {
1247     auto begin = values + offset * counts.front();
1248     sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
1249                                   offsets.drop_front(), sizes.drop_front(),
1250                                   strides.drop_front(), outValues);
1251   }
1252 }
1253 
1254 /// Fold a tensor.extract_slice of an arith.constant into a new arith.constant.
1255 /// The folded operation may duplicate the constant data; users can control when
1256 /// the fold applies via the control function.
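///
/// A hypothetical example of the fold (illustrative IR, values chosen
/// arbitrarily):
///
/// ```mlir
///   %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
///   %0 = tensor.extract_slice %cst[0, 1] [2, 1] [1, 1]
///       : tensor<2x2xi32> to tensor<2x1xi32>
/// ```
///
/// may be folded into:
///
/// ```mlir
///   %0 = arith.constant dense<[[1], [3]]> : tensor<2x1xi32>
/// ```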
1257 class ConstantOpExtractSliceFolder final
1258     : public OpRewritePattern<ExtractSliceOp> {
1259 public:
1260   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1261 
ConstantOpExtractSliceFolder(MLIRContext * context,ControlConstantExtractSliceFusionFn controlFn)1262   ConstantOpExtractSliceFolder(MLIRContext *context,
1263                                ControlConstantExtractSliceFusionFn controlFn)
1264       : OpRewritePattern<ExtractSliceOp>(context),
1265         controlFn(std::move(controlFn)) {}
1266 
matchAndRewrite(ExtractSliceOp op,PatternRewriter & rewriter) const1267   LogicalResult matchAndRewrite(ExtractSliceOp op,
1268                                 PatternRewriter &rewriter) const override {
1269     DenseElementsAttr attr;
1270     if (!matchPattern(op.getSource(), m_Constant(&attr)))
1271       return failure();
1272 
1273     // A constant splat is handled by fold().
1274     if (attr.isSplat())
1275       return failure();
1276 
1277     // Dynamic result shape is not supported.
1278     auto sourceType = op.getSource().getType().cast<ShapedType>();
1279     auto resultType = op.getResult().getType().cast<ShapedType>();
1280     if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
1281       return failure();
1282 
1283     // Customized control over the folding.
1284     if (!controlFn(op))
1285       return failure();
1286 
1287     int64_t count = sourceType.getNumElements();
1288     if (count == 0)
1289       return failure();
1290 
1291     // Check if there are any dynamic parts, which are not supported.
1292     auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
1293     if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
1294       return failure();
1295     auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
1296     if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
1297       return failure();
1298     auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
1299     if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
1300       return failure();
1301 
1302     // Compute the stride for each dimension.
1303     SmallVector<int64_t> counts;
1304     ArrayRef<int64_t> shape = sourceType.getShape();
1305     counts.reserve(shape.size());
1306     for (int64_t v : shape) {
1307       count = count / v;
1308       counts.push_back(count);
1309     }
1310 
1311     // New attribute constructed by the sliced values.
1312     DenseElementsAttr newAttr;
1313 
1314     if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
1315       SmallVector<APInt> outValues;
1316       outValues.reserve(sourceType.getNumElements());
1317       sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
1318           elems.begin(), counts, offsets, sizes, strides, &outValues);
1319       newAttr = DenseElementsAttr::get(resultType, outValues);
1320     } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
1321       SmallVector<APFloat> outValues;
1322       outValues.reserve(sourceType.getNumElements());
1323       sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
1324           elems.begin(), counts, offsets, sizes, strides, &outValues);
1325       newAttr = DenseElementsAttr::get(resultType, outValues);
1326     }
1327 
1328     if (newAttr) {
1329       rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
1330       return success();
1331     }
1332 
1333     return failure();
1334   }
1335 
1336 private:
1337   /// Additional control over whether the fold applies. Users can plug in their
1338   /// own heuristics via this function.
1339   ControlConstantExtractSliceFusionFn controlFn;
1340 };
1341 
1342 } // namespace
1343 
populateFoldConstantExtractSlicePatterns(RewritePatternSet & patterns,const ControlConstantExtractSliceFusionFn & controlFn)1344 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
1345     RewritePatternSet &patterns,
1346     const ControlConstantExtractSliceFusionFn &controlFn) {
1347   patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
1348 }
1349 
1350 /// Return the canonical type of the result of an extract_slice op.
1351 struct SliceReturnTypeCanonicalizer {
operator ()SliceReturnTypeCanonicalizer1352   RankedTensorType operator()(ExtractSliceOp op,
1353                               ArrayRef<OpFoldResult> mixedOffsets,
1354                               ArrayRef<OpFoldResult> mixedSizes,
1355                               ArrayRef<OpFoldResult> mixedStrides) {
1356     return ExtractSliceOp::inferCanonicalRankReducedResultType(
1357         op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
1358         mixedStrides);
1359   }
1360 };
1361 
1362 /// A canonicalizer wrapper to replace ExtractSliceOps.
1363 struct SliceCanonicalizer {
operator ()SliceCanonicalizer1364   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
1365                   ExtractSliceOp newOp) {
1366     Value replacement = newOp.getResult();
1367     if (replacement.getType() != op.getType())
1368       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
1369                                                     replacement);
1370     rewriter.replaceOp(op, replacement);
1371   }
1372 };
1373 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1374 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1375                                                  MLIRContext *context) {
1376   results.add<
1377       OpWithOffsetSizesAndStridesConstantArgumentFolder<
1378           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
1379       ExtractSliceOpCastFolder>(context);
1380 }
1381 
1382 /// Returns success if `op` describes an identity slice of `shapedType`: all-zero offsets, sizes matching the shape, and unit strides.
1383 static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,ShapedType shapedType)1384 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
1385                                            ShapedType shapedType) {
1386   OpBuilder b(op.getContext());
1387   for (OpFoldResult ofr : op.getMixedOffsets())
1388     if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
1389       return failure();
1390   // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
1391   // is appropriate.
1392   auto shape = shapedType.getShape();
1393   for (auto it : llvm::zip(op.getMixedSizes(), shape))
1394     if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
1395       return failure();
1396   for (OpFoldResult ofr : op.getMixedStrides())
1397     if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
1398       return failure();
1399   return success();
1400 }
1401 
1402 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice,
1403 /// we can return the InsertSliceOp's source directly.
1404 // TODO: This only checks the immediate producer; extend to go up the
1405 // insert/extract chain if the slices are disjoint.
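///
/// For instance (illustrative IR, shapes chosen arbitrarily):
///
/// ```mlir
///   %0 = tensor.insert_slice %src into %dest[0, 1] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
///   %1 = tensor.extract_slice %0[0, 1] [4, 4] [1, 1]
///       : tensor<8x8xf32> to tensor<4x4xf32>
/// ```
///
/// Here %1 folds to %src.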
foldExtractAfterInsertSlice(ExtractSliceOp extractOp)1406 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
1407   auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
1408 
1409   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1410   if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
1411       insertOp.isSameAs(extractOp, isSame))
1412     return insertOp.getSource();
1413 
1414   return {};
1415 }
1416 
fold(ArrayRef<Attribute> operands)1417 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
1418   if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
1419     auto resultType = getResult().getType().cast<ShapedType>();
1420     if (resultType.hasStaticShape())
1421       return splat.resizeSplat(resultType);
1422   }
1423   if (getSourceType() == getType() &&
1424       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1425     return this->getSource();
1426   if (Value slice = foldExtractAfterInsertSlice(*this))
1427     return slice;
1428 
1429   return OpFoldResult();
1430 }
1431 
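/// Create (or fold) a rank-reducing ExtractSliceOp that extracts all of `tensor`
/// (zero offsets, unit strides, sizes equal to the dimensions of `tensor`), with
/// the rank-reduced `targetType` as the result type.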
createCanonicalRankReducingExtractSliceOp(OpBuilder & b,Location loc,Value tensor,RankedTensorType targetType)1432 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
1433     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
1434   auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
1435   unsigned rank = rankedTensorType.getRank();
1436   auto shape = rankedTensorType.getShape();
1437   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1438   SmallVector<OpFoldResult> sizes;
1439   for (unsigned i = 0, e = rank; i < e; ++i) {
1440     OpFoldResult dim;
1441     if (rankedTensorType.isDynamicDim(i))
1442       dim = b.createOrFold<tensor::DimOp>(
1443           loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
1444     else
1445       dim = b.getIndexAttr(shape[i]);
1446     sizes.push_back(dim);
1447   }
1448   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1449   return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
1450                                                 offsets, sizes, strides);
1451 }
1452 
1453 //===----------------------------------------------------------------------===//
1454 // InsertSliceOp
1455 //===----------------------------------------------------------------------===//
1456 
1457 // Build an InsertSliceOp with mixed static and dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1458 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1459                           Value dest, ArrayRef<OpFoldResult> offsets,
1460                           ArrayRef<OpFoldResult> sizes,
1461                           ArrayRef<OpFoldResult> strides,
1462                           ArrayRef<NamedAttribute> attrs) {
1463   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1464   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1465   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1466                              ShapedType::kDynamicStrideOrOffset);
1467   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1468                              ShapedType::kDynamicSize);
1469   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1470                              ShapedType::kDynamicStrideOrOffset);
1471   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1472         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1473         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1474   result.addAttributes(attrs);
1475 }
1476 
1477 // Build an InsertSliceOp with dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1478 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1479                           Value dest, ValueRange offsets, ValueRange sizes,
1480                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1481   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1482       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1483   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1484       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1485   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1486       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1487   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1488 }
1489 
1490 /// Rank-reducing type verification for both InsertSliceOp and
1491 /// ParallelInsertSliceOp.
1492 static SliceVerificationResult
verifyInsertSliceOp(ShapedType srcType,ShapedType dstType,ArrayAttr staticOffsets,ArrayAttr staticSizes,ArrayAttr staticStrides,ShapedType * expectedType=nullptr)1493 verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
1494                     ArrayAttr staticOffsets, ArrayAttr staticSizes,
1495                     ArrayAttr staticStrides,
1496                     ShapedType *expectedType = nullptr) {
1497   // insert_slice is the inverse of extract_slice, so use the same type inference.
1498   RankedTensorType expected = ExtractSliceOp::inferResultType(
1499       dstType, extractFromI64ArrayAttr(staticOffsets),
1500       extractFromI64ArrayAttr(staticSizes),
1501       extractFromI64ArrayAttr(staticStrides));
1502   if (expectedType)
1503     *expectedType = expected;
1504   return isRankReducedType(expected, srcType);
1505 }
1506 
1507 /// Verifier for InsertSliceOp.
verify()1508 LogicalResult InsertSliceOp::verify() {
1509   ShapedType expectedType;
1510   SliceVerificationResult result =
1511       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
1512                           getStaticSizes(), getStaticStrides(), &expectedType);
1513   return produceSliceErrorMsg(result, *this, expectedType);
1514 }
1515 
1516 /// If we have two consecutive InsertSliceOps writing to the same slice, we
1517 /// can mutate the second InsertSliceOp's destination to the first one's.
1518 /// This works similarly when the second op is a ParallelInsertSliceOp.
1519 ///
1520 /// Example:
1521 ///
1522 /// ```mlir
1523 ///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
1524 ///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
1525 /// ```
1526 ///
1527 /// folds into:
1528 ///
1529 /// ```mlir
1530 ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
1531 /// ```
1532 ///
1533 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1534 template <typename InsertOpTy>
foldInsertAfterInsertSlice(InsertOpTy insertOp)1535 static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
1536   auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
1537 
1538   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1539   if (!prevInsertOp ||
1540       prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
1541       !prevInsertOp.isSameAs(insertOp, isSame))
1542     return failure();
1543 
1544   insertOp.getDestMutable().assign(prevInsertOp.getDest());
1545   return success();
1546 }
1547 
1548 /// Shared folding logic for InsertSliceOp and ParallelInsertSliceOp; since the
1549 /// return type differs between the two, the result is wrapped in a FailureOr.
1550 ///
1551 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1552 template <typename InsertOpTy>
foldInsertOp(InsertOpTy insertOp,ArrayRef<Attribute>)1553 FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
1554   if (insertOp.getSourceType().hasStaticShape() &&
1555       insertOp.getDestType().hasStaticShape() &&
1556       insertOp.getSourceType() == insertOp.getDestType() &&
1557       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
1558           insertOp, insertOp.getDestType())))
1559     return static_cast<OpFoldResult>(insertOp.getSource());
1560   if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
1561     // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
1562     // return OpFoldResult().
1563     if (std::is_same<InsertOpTy, InsertSliceOp>::value)
1564       return static_cast<OpFoldResult>(insertOp->getResult(0));
1565     else
1566       return OpFoldResult();
1567   }
1568   return failure();
1569 }
1570 
fold(ArrayRef<Attribute> operands)1571 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
1572   auto maybeOpFoldResult = foldInsertOp(*this, operands);
1573   return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
1574 }
1575 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)1576 LogicalResult InsertSliceOp::reifyResultShapes(
1577     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1578   reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1579   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1580     reifiedReturnShapes[0][dim] =
1581         builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
1582   }
1583   return success();
1584 }
1585 
1586 namespace {
1587 /// Pattern to rewrite an insert_slice op with constant arguments.
1588 ///
1589 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
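///
/// As an illustration (hypothetical IR; %c64 is assumed to be defined by
/// `arith.constant 64 : index`):
///
/// ```mlir
///   %r = tensor.insert_slice %src into %dest[0, 0] [%c64, %c64] [1, 1]
///       : tensor<?x?xf32> into tensor<128x128xf32>
/// ```
///
/// is rewritten so that the sizes become the static attributes [64, 64], with a
/// tensor.cast of %src to tensor<64x64xf32> inserted where needed.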
1590 template <typename InsertOpTy>
1591 class InsertSliceOpConstantArgumentFolder final
1592     : public OpRewritePattern<InsertOpTy> {
1593 public:
1594   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1595 
matchAndRewrite(InsertOpTy insertSliceOp,PatternRewriter & rewriter) const1596   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1597                                 PatternRewriter &rewriter) const override {
1598     // No constant operand, just return.
1599     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1600           return matchPattern(operand, matchConstantIndex());
1601         }))
1602       return failure();
1603 
1604     // At least one of offsets/sizes/strides is a new constant.
1605     // Form the new list of operands and constant attributes from the
1606     // existing.
1607     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1608     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1609     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1610     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1611     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1612     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1613 
1614     // Create the new op in canonical form.
1615     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
1616         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
1617         mixedOffsets, mixedSizes, mixedStrides);
1618     Value toInsert = insertSliceOp.getSource();
1619     if (sourceType != insertSliceOp.getSourceType()) {
1620       OpBuilder::InsertionGuard g(rewriter);
1621       // The only difference between InsertSliceOp and ParallelInsertSliceOp is
1622       // that the insertion point is just before the ParallelCombiningOp in the
1623       // parallel case.
1624       if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
1625         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
1626       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
1627                                                  sourceType, toInsert);
1628     }
1629     rewriter.replaceOpWithNewOp<InsertOpTy>(
1630         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
1631         mixedSizes, mixedStrides);
1632     return success();
1633   }
1634 };
1635 
1636 /// Fold tensor_casts with insert_slice operations. If the source or destination
1637 /// tensor is a tensor_cast that removes static type information, the cast is
1638 /// folded into the insert_slice operation. E.g.:
1639 ///
1640 /// ```mlir
1641 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1642 ///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1643 /// ```
1644 ///
1645 /// folds into:
1646 ///
1647 /// ```mlir
1648 ///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1649 /// ```
1650 ///
1651 /// Note: When folding a cast on the destination tensor, the result of the
1652 /// insert_slice operation is cast back to ensure that the result type does not
1653 /// change.
1654 ///
1655 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1656 template <typename InsertOpTy>
1657 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
1658   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1659 
matchAndRewrite__anon3fb9f79f1111::InsertSliceOpCastFolder1660   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1661                                 PatternRewriter &rewriter) const override {
1662     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1663           return matchPattern(operand, matchConstantIndex());
1664         }))
1665       return failure();
1666 
1667     auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1668       auto castOp = v.getDefiningOp<tensor::CastOp>();
1669       if (!castOp || !canFoldIntoConsumerOp(castOp))
1670         return llvm::None;
1671       return castOp.getSource();
1672     };
1673     Optional<Value> sourceCastSource =
1674         getSourceOfCastOp(insertSliceOp.getSource());
1675     Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
1676     if (!sourceCastSource && !destCastSource)
1677       return failure();
1678 
1679     auto src =
1680         (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
1681     auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
1682     auto srcType = src.getType().template cast<ShapedType>();
1683     auto dstType = dst.getType().template cast<ShapedType>();
1684     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
1685                             insertSliceOp.getStaticSizes(),
1686                             insertSliceOp.getStaticStrides()) !=
1687         SliceVerificationResult::Success)
1688       return failure();
1689 
1690     Operation *replacement = rewriter.create<InsertOpTy>(
1691         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
1692         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
1693 
1694     // In the parallel case there is no result and so nothing to cast.
1695     bool isParallelInsert =
1696         std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
1697     if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
1698       replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
1699                                                     insertSliceOp.getDestType(),
1700                                                     replacement->getResult(0));
1701     }
1702     rewriter.replaceOp(insertSliceOp, replacement->getResults());
1703     return success();
1704   }
1705 };
1706 
1707 /// If additional static type information can be deduced from an insert_slice's
1708 /// size operands, insert an explicit cast of the op's source operand. This
1709 /// enables other canonicalization patterns that match on tensor_cast ops, such
1710 /// as `ForOpTensorCastFolder` in SCF.
1711 ///
1712 /// Example:
1713 ///
1714 /// ```mlir
1715 ///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1716 ///       : tensor<?x?xf32> into ...
1717 /// ```
1718 ///
1719 /// folds into:
1720 ///
1721 /// ```mlir
1722 ///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1723 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1724 ///       : tensor<64x64xf32> into ...
1725 /// ```
1726 ///
1727 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1728 template <typename InsertOpTy>
1729 struct InsertSliceOpSourceCastInserter final
1730     : public OpRewritePattern<InsertOpTy> {
1731   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1732 
matchAndRewrite__anon3fb9f79f1111::InsertSliceOpSourceCastInserter1733   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1734                                 PatternRewriter &rewriter) const override {
1735     RankedTensorType srcType = insertSliceOp.getSourceType();
1736     if (srcType.getRank() != insertSliceOp.getDestType().getRank())
1737       return failure();
1738     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1739                                      srcType.getShape().end());
1740     for (int64_t i = 0; i < srcType.getRank(); ++i) {
1741       if (Optional<int64_t> constInt =
1742               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1743         newSrcShape[i] = *constInt;
1744     }
1745 
1746     RankedTensorType newSrcType =
1747         RankedTensorType::get(newSrcShape, srcType.getElementType());
1748     if (srcType == newSrcType ||
1749         !preservesStaticInformation(srcType, newSrcType) ||
1750         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
1751       return failure();
1752 
1753     // newSrcType is:
1754     //   1) Different from srcType.
1755     //   2) "More static" than srcType.
1756     //   3) Cast-compatible with srcType.
1757     // Insert the cast.
1758     OpBuilder::InsertionGuard g(rewriter);
1759     // The only difference between InsertSliceOp and ParallelInsertSliceOp is
1760     // that the insertion point is just before the ParallelCombiningOp in the
1761     // parallel case.
1762     if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
1763       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
1764     Value cast = rewriter.create<tensor::CastOp>(
1765         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
1766     rewriter.replaceOpWithNewOp<InsertOpTy>(
1767         insertSliceOp, cast, insertSliceOp.getDest(),
1768         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1769         insertSliceOp.getMixedStrides());
1771     return success();
1772   }
1773 };
1774 } // namespace
1775 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1776 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1777                                                 MLIRContext *context) {
1778   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
1779               InsertSliceOpCastFolder<InsertSliceOp>,
1780               InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
1781 }
1782 
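/// Create (or fold) a rank-reducing InsertSliceOp that inserts `tensor` into
/// `dest` at zero offsets with unit strides, using the full sizes of `dest`.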
createCanonicalRankReducingInsertSliceOp(OpBuilder & b,Location loc,Value tensor,Value dest)1783 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
1784                                                              Location loc,
1785                                                              Value tensor,
1786                                                              Value dest) {
1787   auto rankedTensorType = dest.getType().cast<RankedTensorType>();
1788   unsigned rank = rankedTensorType.getRank();
1789   auto shape = rankedTensorType.getShape();
1790   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1791   SmallVector<OpFoldResult> sizes;
1792   for (unsigned i = 0, e = rank; i < e; ++i) {
1793     OpFoldResult dim;
1794     if (rankedTensorType.isDynamicDim(i))
1795       dim = b.createOrFold<tensor::DimOp>(
1796           loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
1797     else
1798       dim = b.getIndexAttr(shape[i]);
1799     sizes.push_back(dim);
1800   }
1801   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1802   return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
1803                                                sizes, strides);
1804 }
1805 
1806 //===----------------------------------------------------------------------===//
1807 // PadOp
1808 //===----------------------------------------------------------------------===//
1809 
1810 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
1811 // supports optional types.
printInferType(OpAsmPrinter & printer,Operation * op,Value optOperand,Type typeToInfer,Type typeToInferFrom)1812 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
1813                     Type typeToInfer, Type typeToInferFrom) {}
1814 
parseInferType(OpAsmParser & parser,Optional<OpAsmParser::UnresolvedOperand> optOperand,Type & typeToInfer,Type typeToInferFrom)1815 ParseResult parseInferType(OpAsmParser &parser,
1816                            Optional<OpAsmParser::UnresolvedOperand> optOperand,
1817                            Type &typeToInfer, Type typeToInferFrom) {
1818   if (optOperand)
1819     typeToInfer = typeToInferFrom;
1820   return success();
1821 }
1822 
verify()1823 LogicalResult PadOp::verify() {
1824   auto sourceType = getSource().getType().cast<RankedTensorType>();
1825   auto resultType = getResult().getType().cast<RankedTensorType>();
1826   auto expectedType = PadOp::inferResultType(
1827       sourceType, extractFromI64ArrayAttr(getStaticLow()),
1828       extractFromI64ArrayAttr(getStaticHigh()));
1829   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
1830     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
1831       continue;
1832     if (expectedType.isDynamicDim(i))
1833       continue;
1834     return emitError("specified type ")
1835            << resultType << " does not match the inferred type "
1836            << expectedType;
1837   }
1838 
1839   return success();
1840 }
1841 
verifyRegions()1842 LogicalResult PadOp::verifyRegions() {
1843   auto &region = getRegion();
1844   unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
1845   Block &block = region.front();
1846   if (block.getNumArguments() != rank)
1847     return emitError("expected the block to have ") << rank << " arguments";
1848 
1849   // Note: the number and type of yield values are checked in the YieldOp.
1850   for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
1851     if (!en.value().isIndex())
1852       return emitOpError("expected block argument ")
1853              << (en.index() + 1) << " to be an index";
1854   }
1855 
1856   // Ensure that the region yields an element of the right type.
1857   auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
1858   if (yieldOp.getValue().getType() !=
1859       getType().cast<ShapedType>().getElementType())
1860     return emitOpError("expected yield type to match shape element type");
1861 
1862   return success();
1863 }
1864 
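/// Infer the result type of a PadOp from the source type and the static low and
/// high padding amounts. A result dimension is the sum of the source dimension
/// and both padding amounts when all three are static; otherwise it is taken
/// from `resultShape` if provided, and is dynamic if not.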
inferResultType(RankedTensorType sourceType,ArrayRef<int64_t> staticLow,ArrayRef<int64_t> staticHigh,ArrayRef<int64_t> resultShape)1865 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
1866                                         ArrayRef<int64_t> staticLow,
1867                                         ArrayRef<int64_t> staticHigh,
1868                                         ArrayRef<int64_t> resultShape) {
1869   unsigned rank = sourceType.getRank();
1870   assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
1871   assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
1872   assert((resultShape.empty() || resultShape.size() == rank) &&
1873          "unexpected resultShape size mismatch");
1874 
1875   SmallVector<int64_t, 4> inferredShape;
1876   for (auto i : llvm::seq<unsigned>(0, rank)) {
1877     if (sourceType.isDynamicDim(i) ||
1878         staticLow[i] == ShapedType::kDynamicSize ||
1879         staticHigh[i] == ShapedType::kDynamicSize) {
1880       inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
1881                                                   : resultShape[i]);
1882     } else {
1883       int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
1884       assert((resultShape.empty() || size == resultShape[i] ||
1885               resultShape[i] == ShapedType::kDynamicSize) &&
1886              "mismatch between inferred shape and result shape");
1887       inferredShape.push_back(size);
1888     }
1889   }
1890 
1891   return RankedTensorType::get(inferredShape, sourceType.getElementType());
1892 }
1893 
build(OpBuilder & b,OperationState & result,Value source,ArrayRef<int64_t> staticLow,ArrayRef<int64_t> staticHigh,ValueRange low,ValueRange high,bool nofold,ArrayRef<NamedAttribute> attrs)1894 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1895                   ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
1896                   ValueRange low, ValueRange high, bool nofold,
1897                   ArrayRef<NamedAttribute> attrs) {
1898   auto sourceType = source.getType().cast<RankedTensorType>();
1899   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
1900   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
1901         b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
1902   result.addAttributes(attrs);
1903 }
1904 
build(OpBuilder & b,OperationState & result,Value source,ValueRange low,ValueRange high,bool nofold,ArrayRef<NamedAttribute> attrs)1905 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1906                   ValueRange low, ValueRange high, bool nofold,
1907                   ArrayRef<NamedAttribute> attrs) {
1908   auto sourceType = source.getType().cast<RankedTensorType>();
1909   unsigned rank = sourceType.getRank();
1910   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
1911   build(b, result, source, staticVector, staticVector, low, high, nofold,
1912         attrs);
1913 }
1914 
build(OpBuilder & b,OperationState & result,Type resultType,Value source,ArrayRef<OpFoldResult> low,ArrayRef<OpFoldResult> high,bool nofold,ArrayRef<NamedAttribute> attrs)1915 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
1916                   Value source, ArrayRef<OpFoldResult> low,
1917                   ArrayRef<OpFoldResult> high, bool nofold,
1918                   ArrayRef<NamedAttribute> attrs) {
1919   assert(resultType.isa<RankedTensorType>());
1920   auto sourceType = source.getType().cast<RankedTensorType>();
1921   SmallVector<Value, 4> dynamicLow, dynamicHigh;
1922   SmallVector<int64_t, 4> staticLow, staticHigh;
1923   // staticLow and staticHigh have full information of the padding config.
1924   // This will grow staticLow and staticHigh with one value per padding entry.
1925   // If an entry is dynamic (i.e., not a constant), dynamicLow and dynamicHigh
1926   // will grow with one value as well.
1927   dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
1928                              ShapedType::kDynamicSize);
1929   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
1930                              ShapedType::kDynamicSize);
1931   if (!resultType) {
1932     resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
1933   }
1934   build(b, result, resultType, source, dynamicLow, dynamicHigh,
1935         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
1936         nofold ? b.getUnitAttr() : UnitAttr());
1937   result.addAttributes(attrs);
1938 }
1939 
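/// Return a bit vector with one entry per source dimension; a bit is set if the
/// corresponding dimension has a low or high padding amount that is not
/// statically known to be zero.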
getPaddedDims()1940 llvm::SmallBitVector PadOp::getPaddedDims() {
1941   llvm::SmallBitVector paddedDims(getSourceType().getRank());
1942   auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
1943     for (const auto &en : enumerate(paddingWidths))
1944       if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
1945         paddedDims.set(en.index());
1946   };
1947   extractPaddedDims(getMixedLowPad());
1948   extractPaddedDims(getMixedHighPad());
1949   return paddedDims;
1950 }
1951 
1952 namespace {
1953 // Folds tensor.pad when the padding is all static zeros and the nofold
1954 // attribute doesn't request otherwise.
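//
// For example (illustrative IR):
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] { ... }
//       : tensor<4x4xf32> to tensor<4x4xf32>
//
// is replaced with a tensor.cast of %src to the result type (which in turn
// folds away when the types already match).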
1955 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
1956   using OpRewritePattern<PadOp>::OpRewritePattern;
1957 
matchAndRewrite__anon3fb9f79f1611::FoldStaticZeroPadding1958   LogicalResult matchAndRewrite(PadOp padTensorOp,
1959                                 PatternRewriter &rewriter) const override {
1960     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1961       return failure();
1962     if (padTensorOp.getNofold())
1963       return failure();
1964     rewriter.replaceOpWithNewOp<tensor::CastOp>(
1965         padTensorOp, padTensorOp.getResult().getType(),
1966         padTensorOp.getSource());
1967     return success();
1968   }
1969 };
1970 
1971 // Fold CastOp into PadOp when adding static information.
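//
// A sketch of the rewrite (hypothetical IR):
//
//   %0 = tensor.cast %src : tensor<8x8xf32> to tensor<?x?xf32>
//   %1 = tensor.pad %0 low[1, 1] high[1, 1] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
//
// becomes a pad of %src with the inferred result type tensor<10x10xf32>,
// followed by a tensor.cast back to the original tensor<?x?xf32> result type.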
1972 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
1973   using OpRewritePattern<PadOp>::OpRewritePattern;
1974 
matchAndRewrite__anon3fb9f79f1611::FoldSourceTensorCast1975   LogicalResult matchAndRewrite(PadOp padTensorOp,
1976                                 PatternRewriter &rewriter) const override {
1977     auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
1978     if (!tensor::canFoldIntoConsumerOp(castOp))
1979       return failure();
1980 
1981     auto newResultType = PadOp::inferResultType(
1982         castOp.getSource().getType().cast<RankedTensorType>(),
1983         extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
1984         extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
1985         padTensorOp.getResultType().getShape());
1986 
1987     if (newResultType == padTensorOp.getResultType()) {
1988       rewriter.updateRootInPlace(padTensorOp, [&]() {
1989         padTensorOp.getSourceMutable().assign(castOp.getSource());
1990       });
1991     } else {
1992       auto newOp = rewriter.create<PadOp>(
1993           padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
1994           padTensorOp.getLow(), padTensorOp.getHigh(),
1995           padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
1996           padTensorOp.getNofold());
1997       BlockAndValueMapping mapper;
1998       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
1999 
2000       rewriter.replaceOpWithNewOp<tensor::CastOp>(
2001           padTensorOp, padTensorOp.getResultType(), newOp);
2002     }
2003     return success();
2004   }
2005 };
2006 
2007 // Fold a CastOp of a PadOp result back into the PadOp if the cast adds
2008 // static information.
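//
// For instance (illustrative IR):
//
//   %0 = tensor.pad %src low[0, 0] high[%h, %h] { ... }
//       : tensor<8x8xf32> to tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<10x10xf32>
//
// is rewritten into a single tensor.pad that produces tensor<10x10xf32>
// directly, provided the pad result has no other uses.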
2009 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2010   using OpRewritePattern<PadOp>::OpRewritePattern;
2011 
matchAndRewrite__anon3fb9f79f1611::FoldTargetTensorCast2012   LogicalResult matchAndRewrite(PadOp padTensorOp,
2013                                 PatternRewriter &rewriter) const override {
2014     if (!padTensorOp.getResult().hasOneUse())
2015       return failure();
2016     auto tensorCastOp =
2017         dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2018     if (!tensorCastOp)
2019       return failure();
2020     if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
2021                                             tensorCastOp.getDest().getType()))
2022       return failure();
2023 
2024     auto replacementOp = rewriter.create<PadOp>(
2025         padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2026         padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
2027         padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2028         padTensorOp.getNofold());
2029     replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2030 
2031     rewriter.replaceOp(padTensorOp, replacementOp.getResult());
2032     rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
2033     return success();
2034   }
2035 };
2036 
2037 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2038 /// different dimensions. The pattern applies if the following preconditions
2039 /// hold:
2040 ///   1) the tensor::ExtractSliceOps are not rank-reducing,
2041 ///   2) the tensor::ExtractSliceOps have only unit-strides,
2042 ///   3) the tensor::PadOps perform only high-padding,
2043 ///   4) the tensor::PadOps have the same constant padding value,
2044 ///   5) the tensor::PadOps do not have common padding dimensions,
2045 ///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2046 ///      zero-offset for every dimension.
2047 ///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
2048 ///      padded source dimensions.
2049 ///
2050 /// Example:
2051 ///
2052 /// ```mlir
2053 ///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2054 ///       : tensor<64x64xf32> to tensor<?x64xf32>
2055 ///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2056 ///     } : tensor<?x64xf32> to tensor<8x64xf32>
2057 ///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2058 ///        : tensor<8x64xf32> to tensor<8x?xf32>
2059 ///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2060 ///     } : tensor<8x?xf32> to tensor<8x4xf32>
2061 /// ```
2062 ///
2063 /// folds into:
2064 ///
2065 /// ```mlir
2066 ///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2067 ///        : tensor<64x64xf32> to tensor<?x?xf32>
2068 ///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2069 ///     } : tensor<?x?xf32> to tensor<8x4xf32>
2070 /// ```
2071 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
2072   using OpRewritePattern<PadOp>::OpRewritePattern;
2073 
matchAndRewrite__anon3fb9f79f1611::FoldOrthogonalPaddings2074   LogicalResult matchAndRewrite(PadOp padOp,
2075                                 PatternRewriter &rewriter) const override {
2076     auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
2077     if (!innerSliceOp)
2078       return failure();
2079     auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
2080     if (!outerPadOp || outerPadOp.getNofold())
2081       return failure();
2082     auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
2083     if (!outerSliceOp)
2084       return failure();
2085 
2086     // 1) Fail if the chain is rank-reducing.
2087     int64_t rank = padOp.getSourceType().getRank();
2088     if (outerSliceOp.getSourceType().getRank() != rank) {
2089       return rewriter.notifyMatchFailure(padOp,
2090                                          "cannot fold rank-reducing chain");
2091     }
2092 
2093     // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2094     if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
2095       return rewriter.notifyMatchFailure(
2096           padOp, "cannot fold non-unit stride ExtractSliceOps");
2097     }
2098 
2099     // 3) Fail if the tensor::PadOps have non-zero low padding.
2100     if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
2101       return rewriter.notifyMatchFailure(padOp,
2102                                          "cannot fold PadOps with low padding");
2103     }
2104 
2105     // 4) Fail if the tensor::PadOps padding values do not match.
2106     Attribute innerAttr, outerAttr;
2107     Value innerValue = padOp.getConstantPaddingValue();
2108     Value outerValue = outerPadOp.getConstantPaddingValue();
2109     if (!innerValue || !outerValue ||
2110         !matchPattern(innerValue, m_Constant(&innerAttr)) ||
2111         !matchPattern(outerValue, m_Constant(&outerAttr)) ||
2112         innerAttr != outerAttr) {
2113       return rewriter.notifyMatchFailure(
2114           padOp, "cannot fold PadOps with different padding values");
2115     }
2116 
2117     // 5) Fail if a dimension is padded by both tensor::PadOps.
2118     llvm::SmallBitVector innerDims = padOp.getPaddedDims();
2119     llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
2120     if (innerDims.anyCommon(outerDims)) {
2121       return rewriter.notifyMatchFailure(
2122           padOp, "cannot fold PadOps with common padding dimensions");
2123     }
2124 
2125     // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2126     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2127     // for every dimension, and use the offset of the other pair. Fail if no
2128     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2129     // exists.
2130     SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
2131     for (auto &en : enumerate(newOffsets)) {
2132       OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
2133       OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
2134       if (!innerDims.test(en.index()) &&
2135           (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
2136         en.value() = outerOffset;
2137         continue;
2138       }
2139       if (!outerDims.test(en.index()) &&
2140           (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
2141         en.value() = innerOffset;
2142         continue;
2143       }
2144       return rewriter.notifyMatchFailure(
2145           padOp, "cannot find zero-offset and zero-padding pair");
2146     }
2147 
2148     // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
2149     // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
2150     // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
2151     // does not match the size of the padded dimension. Otherwise, take the size
2152     // of the inner tensor::ExtractSliceOp.
2153     SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
2154     for (auto &en : enumerate(newSizes)) {
2155       if (!outerDims.test(en.index()))
2156         continue;
2157       OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
2158       int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
2159       assert(!ShapedType::isDynamic(sourceSize) &&
2160              "expected padded dimension to have a static size");
2161       if (getConstantIntValue(sliceSize) != sourceSize) {
2162         return rewriter.notifyMatchFailure(
2163             padOp, "cannot fold since the inner ExtractSliceOp size does not "
2164                    "match the size of the outer padding");
2165       }
2166       en.value() = outerSliceOp.getMixedSizes()[en.index()];
2167     }
2168 
2169     // Combine the high paddings of the two tensor::PadOps.
2170     SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
2171     for (auto &en : enumerate(newHighPad)) {
2172       if (innerDims.test(en.index()))
2173         newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
2174       if (outerDims.test(en.index()))
2175         newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
2176     }
2177 
2178     // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
2179     // two paddings in one step.
2180     auto newSliceOp = rewriter.create<ExtractSliceOp>(
2181         padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
2182         innerSliceOp.getMixedStrides());
2183     auto newPadOp = rewriter.create<PadOp>(
2184         padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
2185         padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
2186     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
2187                                 newPadOp.getRegion().begin());
2188     rewriter.replaceOp(padOp, newPadOp.getResult());
2189     return success();
2190   }
2191 };
2192 
2193 } // namespace
2194 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2195 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2196                                         MLIRContext *context) {
2197   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2198               FoldOrthogonalPaddings>(context);
2199 }
2200 
2201 /// Return the padding value of the PadOp if it is constant. In this context,
2202 /// "constant" means an actual constant or "defined outside of the block".
2203 ///
2204 /// Values are considered constant in three cases:
2205 ///  - A ConstantLike value.
2206 ///  - A basic block argument from a different block.
2207 ///  - A value defined outside of the block.
2208 ///
2209 /// If the padding value is not constant, an empty Value is returned.
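///
/// For example (illustrative IR), the yielded %cst below is a constant padding
/// value:
///
/// ```mlir
///   %cst = arith.constant 0.0 : f32
///   %0 = tensor.pad %src low[1, 1] high[1, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<4x4xf32> to tensor<6x6xf32>
/// ```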
getConstantPaddingValue()2210 Value PadOp::getConstantPaddingValue() {
2211   auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
2212   if (!yieldOp)
2213     return {};
2214   Value padValue = yieldOp.getValue();
2215   // Check if yield value is a constant.
2216   if (matchPattern(padValue, m_Constant()))
2217     return padValue;
2218   // Check if yield value is defined inside the PadOp block.
2219   if (padValue.getParentBlock() == &getRegion().front())
2220     return {};
2221   // Else: Yield value defined outside of the PadOp block.
2222   return padValue;
2223 }
2224 
fold(ArrayRef<Attribute>)2225 OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
2226   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
2227       !getNofold())
2228     return getSource();
2229   return {};
2230 }
2231 
2232 //===----------------------------------------------------------------------===//
2233 // ParallelInsertSliceOp
2234 //===----------------------------------------------------------------------===//
2235 
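/// Return the result of the enclosing ParallelCombiningOpInterface op that is
/// tied to this ParallelInsertSliceOp, i.e., the parent result whose position
/// matches this op's position among the parent's yielding ops.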
getTiedOpResult()2236 OpResult ParallelInsertSliceOp::getTiedOpResult() {
2237   ParallelCombiningOpInterface parallelCombiningParent =
2238       getParallelCombiningParent();
2239   for (const auto &it :
2240        llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
2241     Operation &nextOp = it.value();
2242     if (&nextOp == getOperation())
2243       return parallelCombiningParent.getParentResult(it.index());
2244   }
2245   llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
2246 }
2247 
2248 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)2249 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
2250                                   Value source, Value dest,
2251                                   ArrayRef<OpFoldResult> offsets,
2252                                   ArrayRef<OpFoldResult> sizes,
2253                                   ArrayRef<OpFoldResult> strides,
2254                                   ArrayRef<NamedAttribute> attrs) {
2255   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2256   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2257   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2258                              ShapedType::kDynamicStrideOrOffset);
2259   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2260                              ShapedType::kDynamicSize);
2261   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2262                              ShapedType::kDynamicStrideOrOffset);
2263   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
2264         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
2265         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
2266   result.addAttributes(attrs);
2267 }
2268 
2269 // Build a ParallelInsertSliceOp with dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)2270 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
2271                                   Value source, Value dest, ValueRange offsets,
2272                                   ValueRange sizes, ValueRange strides,
2273                                   ArrayRef<NamedAttribute> attrs) {
2274   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2275       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2276   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2277       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2278   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2279       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2280   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2281 }
2282 
verify()2283 LogicalResult ParallelInsertSliceOp::verify() {
2284   if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
2285     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
2286            << *(getOperation()->getParentOp());
2287 
2288   ShapedType expectedType;
2289   SliceVerificationResult result =
2290       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
2291                           getStaticSizes(), getStaticStrides(), &expectedType);
2292   return produceSliceErrorMsg(result, *this, expectedType);
2293 }
2294 
2295 namespace {
2296 /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
2297 class ParallelInsertSliceOpConstantArgumentFolder final
2298     : public OpRewritePattern<ParallelInsertSliceOp> {
2299 public:
2300   using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
2301 
matchAndRewrite(ParallelInsertSliceOp insertSliceOp,PatternRewriter & rewriter) const2302   LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
2303                                 PatternRewriter &rewriter) const override {
2304     // No constant operand, just return.
2305     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
2306           return matchPattern(operand, matchConstantIndex());
2307         }))
2308       return failure();
2309 
2310     // At least one of offsets/sizes/strides is a new constant.
2311     // Form the new list of operands and constant attributes from the
2312     // existing.
2313     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2314     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2315     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2316     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
2317     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
2318     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
2319 
2320     // Create the new op in canonical form.
2321     auto sourceType =
2322         tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
2323             insertSliceOp.getSourceType().getRank(),
2324             insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
2325             mixedStrides);
2326     Value toInsert = insertSliceOp.getSource();
2327     if (sourceType != insertSliceOp.getSourceType()) {
2328       OpBuilder::InsertionGuard g(rewriter);
2329       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2330       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2331                                                  sourceType, toInsert);
2332     }
2333     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
2334         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2335         mixedSizes, mixedStrides);
2336     return success();
2337   }
2338 };
2339 } // namespace
2340 
2341 LogicalResult
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)2342 ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
2343                             SmallVectorImpl<OpFoldResult> &results) {
2344   return foldInsertOp(*this, operands);
2345 }
2346 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2347 void ParallelInsertSliceOp::getCanonicalizationPatterns(
2348     RewritePatternSet &results, MLIRContext *context) {
2349   results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
2350               InsertSliceOpCastFolder<ParallelInsertSliceOp>,
2351               InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
2352 }
2353 
2354 //===----------------------------------------------------------------------===//
2355 // SplatOp
2356 //===----------------------------------------------------------------------===//
2357 
fold(ArrayRef<Attribute> operands)2358 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2359   auto constOperand = operands.front();
2360   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
2361     return {};
2362 
2363   // SplatElementsAttr::get treats a single value for the second arg as a splat.
2364   return SplatElementsAttr::get(getType(), {constOperand});
2365 }
2366 
2367 //===----------------------------------------------------------------------===//
2368 // TableGen'd op method definitions
2369 //===----------------------------------------------------------------------===//
2370 
2371 #define GET_OP_CLASSES
2372 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
2373