1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
14 #include "mlir/Dialect/Utils/StaticValueUtils.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinAttributeInterfaces.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallBitVector.h"
22
23 using namespace mlir;
24 using namespace mlir::tensor;
25
26 /// Materialize a single constant operation from a given attribute value with
27 /// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
31 if (arith::ConstantOp::isBuildableWith(value, type))
32 return builder.create<arith::ConstantOp>(loc, value, type);
33 if (complex::ConstantOp::isBuildableWith(value, type))
34 return builder.create<complex::ConstantOp>(loc, type,
35 value.cast<ArrayAttr>());
36 return nullptr;
37 }
38
39 //===----------------------------------------------------------------------===//
40 // CastOp
41 //===----------------------------------------------------------------------===//
42
43 /// Returns true if `target` is a ranked tensor type that preserves static
44 /// information available in the `source` ranked tensor type.
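///
/// A few illustrative cases (a sketch, not an exhaustive specification):
///   preservesStaticInformation(tensor<?x?xf32>, tensor<8x16xf32>) == true
///   preservesStaticInformation(tensor<8x16xf32>, tensor<8x16xf32>) == true
///   preservesStaticInformation(tensor<8x16xf32>, tensor<?x16xf32>) == false
///   preservesStaticInformation(tensor<8x16xf32>, tensor<8x16xf64>) == false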
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
46 auto sourceType = source.dyn_cast<RankedTensorType>();
47 auto targetType = target.dyn_cast<RankedTensorType>();
48
49 // Requires RankedTensorType.
50 if (!sourceType || !targetType)
51 return false;
52
53 // Requires same elemental type.
54 if (sourceType.getElementType() != targetType.getElementType())
55 return false;
56
57 // Requires same rank.
58 if (sourceType.getRank() != targetType.getRank())
59 return false;
60
61 // If cast is towards more static sizes along any dimension, don't fold.
62 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
63 if (!ShapedType::isDynamic(std::get<0>(t)) &&
64 ShapedType::isDynamic(std::get<1>(t)))
65 return false;
66 }
67
68 return true;
69 }
70
71 /// Determines whether tensor::CastOp casts to a more dynamic version of the
72 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
73 /// implement canonicalization patterns for ops in different dialects that may
74 /// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized to
/// preserve the type compatibility of their uses.
77 ///
78 /// Returns true when all conditions are met:
79 /// 1. source and result are ranked tensors with same element type and rank.
/// 2. the source type has at least as much static information as the result
///    type.
81 ///
82 /// Example:
83 /// ```mlir
84 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
85 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
86 /// ```
87 ///
88 /// folds into:
89 ///
90 /// ```mlir
91 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
92 /// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
94 if (!castOp)
95 return false;
96
97 // Can fold if the source of cast has at least as much static information as
98 // its results.
99 return preservesStaticInformation(castOp.getType(),
100 castOp.getSource().getType());
101 }
102
103 /// Determines whether the tensor::CastOp casts to a more static version of the
104 /// source tensor. This is useful to fold into a producing op and implement
/// canonicalization patterns with the `tensor.cast` op as the root, but the
/// producer being from a different dialect. Returns true when all conditions
/// are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the result type has at least as much static information as the source
///    type.
109 ///
110 /// Example:
111 /// ```mlir
112 /// %1 = producer ... : tensor<?x?xf32>
113 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
114 /// ```
115 ///
/// can be canonicalized to:
117 ///
118 /// ```mlir
119 /// %2 = producer ... : tensor<8x16xf32>
120 /// ```
121 /// Not all ops might be canonicalizable this way, but for those that can be,
122 /// this method provides a check that it is worth doing the canonicalization.
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
124 if (!castOp)
125 return false;
126 return preservesStaticInformation(castOp.getSource().getType(),
127 castOp.getType());
128 }
129
130 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
131 /// that can be folded.
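///
/// For illustration only (a sketch; "some.op" stands for an arbitrary
/// consuming operation):
/// ```mlir
/// %c = tensor.cast %a : tensor<8x16xf32> to tensor<?x?xf32>
/// %r = "some.op"(%c) : (tensor<?x?xf32>) -> tensor<?x?xf32>
/// ```
/// Folding the consumer's operands replaces %c with %a while leaving the
/// consumer's result type unchanged.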
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
133 bool folded = false;
134 for (OpOperand &operand : op->getOpOperands()) {
135 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
136 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
137 operand.set(castOp.getOperand());
138 folded = true;
139 }
140 }
141 return success(folded);
142 }
143
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
145 if (inputs.size() != 1 || outputs.size() != 1)
146 return false;
147 Type a = inputs.front(), b = outputs.front();
148 auto aT = a.dyn_cast<TensorType>();
149 auto bT = b.dyn_cast<TensorType>();
150 if (!aT || !bT)
151 return false;
152
153 if (aT.getElementType() != bT.getElementType())
154 return false;
155
156 return succeeded(verifyCompatibleShape(aT, bT));
157 }
158
159 /// Compute a TensorType that has the joined shape knowledge of the two
160 /// given TensorTypes. The element types need to match.
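/// For example (an illustrative sketch only): joining tensor<8x?xf32> and
/// tensor<?x16xf32> yields tensor<8x16xf32>, while tensor<8x16xf32> and
/// tensor<4x16xf32> have no join and a null type is returned.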
static TensorType joinShapes(TensorType one, TensorType two) {
162 assert(one.getElementType() == two.getElementType());
163
164 if (!one.hasRank())
165 return two;
166 if (!two.hasRank())
167 return one;
168
169 int64_t rank = one.getRank();
170 if (rank != two.getRank())
171 return {};
172
173 SmallVector<int64_t, 4> join;
174 join.reserve(rank);
175 for (int64_t i = 0; i < rank; ++i) {
176 if (one.isDynamicDim(i)) {
177 join.push_back(two.getDimSize(i));
178 continue;
179 }
180 if (two.isDynamicDim(i)) {
181 join.push_back(one.getDimSize(i));
182 continue;
183 }
184 if (one.getDimSize(i) != two.getDimSize(i))
185 return {};
186 join.push_back(one.getDimSize(i));
187 }
188 return RankedTensorType::get(join, one.getElementType());
189 }
190
191 namespace {
192
/// Replaces a chain of two tensor.cast operations with a single tensor.cast
/// operation if doing so does not remove runtime constraints.
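///
/// Example (an illustrative sketch only):
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x16xf32>
/// %2 = tensor.cast %1 : tensor<?x16xf32> to tensor<?x?xf32>
/// ```
/// folds into:
/// ```mlir
/// %2 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// ```
/// whereas a chain that first erases and then re-asserts a static size is not
/// folded, because that would drop a runtime check.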
195 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
196 using OpRewritePattern<CastOp>::OpRewritePattern;
197
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
200 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
201
202 if (!tensorCastOperand)
203 return failure();
204
205 auto sourceType =
206 tensorCastOperand.getOperand().getType().cast<TensorType>();
207 auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
208 auto resultType = tensorCast.getType().cast<TensorType>();
209
210 // We can remove the intermediate cast if joining all three produces the
211 // same result as just joining the source and result shapes.
212 auto firstJoin =
213 joinShapes(joinShapes(sourceType, intermediateType), resultType);
214
215 // The join might not exist if the cast sequence would fail at runtime.
216 if (!firstJoin)
217 return failure();
218
219 // The newJoin always exists if the above join exists, it might just contain
220 // less information. If so, we cannot drop the intermediate cast, as doing
221 // so would remove runtime checks.
222 auto newJoin = joinShapes(sourceType, resultType);
223 if (firstJoin != newJoin)
224 return failure();
225
226 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
227 tensorCastOperand.getOperand());
228 return success();
229 }
230 };
231
/// Fold tensor.cast into its tensor.extract_slice producer.
233 /// Example:
234 /// ```
235 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
236 /// tensor<128x512xf32> to tensor<?x512xf32>
237 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
238 /// ```
239 /// ->
240 /// ```
241 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
242 /// tensor<128x512xf32> to tensor<16x512xf32>
243 /// ```
244 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
245 using OpRewritePattern<CastOp>::OpRewritePattern;
246
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
249 auto extractOperand =
250 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
251
252 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
253 tensorCast.getType().getShape() == tensorCast.getSource()
254 .getType()
255 .cast<RankedTensorType>()
256 .getShape())
257 return failure();
258
259 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
260 auto dimMask = computeRankReductionMask(
261 extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
262 extractOperand.getType().getShape());
263 size_t dimIndex = 0;
264 for (size_t i = 0, e = sizes.size(); i < e; i++) {
265 if (dimMask && dimMask->count(i))
266 continue;
267 int64_t dim = tensorCast.getType().getShape()[dimIndex++];
268 if (ShapedType::isDynamic(dim))
269 continue;
270 sizes[i] = rewriter.getIndexAttr(dim);
271 }
272
273 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
274 tensorCast, tensorCast.getType().cast<RankedTensorType>(),
275 extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
276 extractOperand.getMixedStrides());
277 return success();
278 }
279 };
280
281 } // namespace
282
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
285 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
286 }
287
288 //===----------------------------------------------------------------------===//
289 // DimOp
290 //===----------------------------------------------------------------------===//
291
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
294 auto loc = result.location;
295 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
296 build(builder, result, source, indexValue);
297 }
298
Optional<int64_t> DimOp::getConstantIndex() {
300 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
301 return constantOp.getValue().cast<IntegerAttr>().getInt();
302 return {};
303 }
304
LogicalResult DimOp::verify() {
306 // Assume unknown index to be in range.
307 Optional<int64_t> index = getConstantIndex();
308 if (!index)
309 return success();
310
311 // Check that constant index is not knowingly out of range.
312 auto type = getSource().getType();
313 if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
314 if (*index >= tensorType.getRank())
315 return emitOpError("index is out of range");
316 } else if (type.isa<UnrankedTensorType>()) {
317 // Assume index to be in range.
318 } else {
319 llvm_unreachable("expected operand with tensor type");
320 }
321 return success();
322 }
323
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
325 // All forms of folding require a known index.
326 auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
327 if (!index)
328 return {};
329
330 // Folding for unranked types (UnrankedTensorType) is not supported.
331 auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
332 if (!tensorType)
333 return {};
334
335 // Fold if the shape extent along the given index is known.
336 if (!tensorType.isDynamicDim(index.getInt())) {
337 Builder builder(getContext());
338 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
339 }
340
341 Operation *definingOp = getSource().getDefiningOp();
342
343 // Fold dim to the operand of tensor.generate.
344 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
345 auto resultType =
346 fromElements.getResult().getType().cast<RankedTensorType>();
347 // The case where the type encodes the size of the dimension is handled
348 // above.
349 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
350
351 // Find the operand of the fromElements that corresponds to this index.
352 auto dynExtents = fromElements.getDynamicExtents().begin();
353 for (auto dim : resultType.getShape().take_front(index.getInt()))
354 if (ShapedType::isDynamic(dim))
355 dynExtents++;
356
357 return Value{*dynExtents};
358 }
359
360 // The size at the given index is now known to be a dynamic size.
361 unsigned unsignedIndex = index.getValue().getZExtValue();
362
363 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
364 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
365 // `resolve-shaped-type-result-dims` pass.
366 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
367 sliceOp.isDynamicSize(unsignedIndex)) {
368 return {sliceOp.getDynamicSize(unsignedIndex)};
369 }
370 }
371
372 // dim(cast) -> dim
373 if (succeeded(foldTensorCast(*this)))
374 return getResult();
375
376 return {};
377 }
378
379 namespace {
380 /// Fold dim of a cast into the dim of the source of the tensor cast.
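///
/// Example (an illustrative sketch; %c1 is an index constant defined
/// elsewhere):
/// ```mlir
/// %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
/// %d = tensor.dim %0, %c1 : tensor<?x?xf32>
/// ```
/// becomes:
/// ```mlir
/// %d = tensor.dim %t, %c1 : tensor<4x?xf32>
/// ```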
381 struct DimOfCastOp : public OpRewritePattern<DimOp> {
382 using OpRewritePattern<DimOp>::OpRewritePattern;
383
  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
386 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
387 if (!castOp)
388 return failure();
389 Value newSource = castOp.getOperand();
390 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
391 return success();
392 }
393 };
394 } // namespace
395
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
398 results.add<DimOfCastOp>(context);
399 }
400
401 //===----------------------------------------------------------------------===//
402 // ExtractOp
403 //===----------------------------------------------------------------------===//
404
LogicalResult ExtractOp::verify() {
406 // Verify the # indices match if we have a ranked type.
407 if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
408 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
409 return emitOpError("incorrect number of indices for extract_element");
410
411 return success();
412 }
413
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
415 // If this is a splat elements attribute, simply return the value. All of the
416 // elements of a splat attribute are the same.
417 if (Attribute tensor = operands.front())
418 if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
419 return splatTensor.getSplatValue<Attribute>();
420
421 // Collect the constant indices into the tensor.
422 SmallVector<uint64_t, 8> indices;
423 for (Attribute indice : llvm::drop_begin(operands, 1)) {
424 if (!indice || !indice.isa<IntegerAttr>())
425 return {};
426 indices.push_back(indice.cast<IntegerAttr>().getInt());
427 }
428
429 // Fold extract(from_elements(...)).
430 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
431 auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
432 auto rank = tensorType.getRank();
433 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
434 "rank mismatch");
    // Compute the row-major flattened index; the stride of dimension i is the
    // product of the sizes of the trailing dimensions i+1 .. rank-1.
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      if (i < rank - 1)
        stride *= tensorType.getDimSize(i + 1);
      flatIndex += indices[i] * stride;
    }
442 // Prevent out of bounds accesses. This can happen in invalid code that will
443 // never execute.
444 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
445 flatIndex < 0)
446 return {};
447 return fromElementsOp.getElements()[flatIndex];
448 }
449
450 // If this is an elements attribute, query the value at the given indices.
451 if (Attribute tensor = operands.front()) {
452 auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
453 if (elementsAttr && elementsAttr.isValidIndex(indices))
454 return elementsAttr.getValues<Attribute>()[indices];
455 }
456
457 return {};
458 }
459
460 //===----------------------------------------------------------------------===//
461 // FromElementsOp
462 //===----------------------------------------------------------------------===//
463
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, ValueRange elements) {
466 result.addOperands(elements);
467 result.addTypes(resultType);
468 }
469
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
472 assert(!elements.empty() && "expected at least one element");
473 Type resultType = RankedTensorType::get(
474 {static_cast<int64_t>(elements.size())}, elements.front().getType());
475 build(builder, result, resultType, elements);
476 }
477
OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
479 if (!llvm::is_contained(operands, nullptr))
480 return DenseElementsAttr::get(getType(), operands);
481 return {};
482 }
483
484 namespace {
485
486 // Pushes the index_casts that occur before extractions to after the extract.
487 // This minimizes type conversion in some cases and enables the extract
488 // canonicalizer. This changes:
489 //
490 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
491 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
492 //
493 // to the following:
494 //
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// so that the cast is applied to the extracted scalar rather than to the
// whole tensor.
//
// Consider expanding this to a template and handling all tensor cast
// operations.
501 struct ExtractElementFromIndexCast
502 : public OpRewritePattern<tensor::ExtractOp> {
503 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
504
  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
507 Location loc = extract.getLoc();
508 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
509 if (!indexCast)
510 return failure();
511
512 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
513
514 auto newExtract = rewriter.create<tensor::ExtractOp>(
515 loc, elementTy, indexCast.getIn(), extract.getIndices());
516
517 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
518 newExtract);
519
520 return success();
521 }
522 };
523
524 } // namespace
525
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
528 results.add<ExtractElementFromIndexCast>(context);
529 }
530
531 //===----------------------------------------------------------------------===//
532 // InsertOp
533 //===----------------------------------------------------------------------===//
534
LogicalResult InsertOp::verify() {
536 // Verify the # indices match if we have a ranked type.
537 if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
538 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
539 return emitOpError("incorrect number of indices");
540 return success();
541 }
542
OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
544 Attribute scalar = operands[0];
545 Attribute dest = operands[1];
546 if (scalar && dest)
547 if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
548 if (scalar == splatDest.getSplatValue<Attribute>())
549 return dest;
550 return {};
551 }
552
553 //===----------------------------------------------------------------------===//
554 // GenerateOp
555 //===----------------------------------------------------------------------===//
556
LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
559 reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
560 int idx = 0;
561 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
562 if (getType().isDynamicDim(dim)) {
563 reifiedReturnShapes[0][dim] = getOperand(idx++);
564 } else {
565 reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
566 getLoc(), getType().getDimSize(dim));
567 }
568 }
569 return success();
570 }
571
LogicalResult GenerateOp::verify() {
573 // Ensure that the tensor type has as many dynamic dimensions as are specified
574 // by the operands.
575 RankedTensorType resultTy = getType().cast<RankedTensorType>();
576 if (getNumOperands() != resultTy.getNumDynamicDims())
577 return emitError("must have as many index operands as dynamic extents "
578 "in the result type");
579
580 return success();
581 }
582
LogicalResult GenerateOp::verifyRegions() {
584 RankedTensorType resultTy = getType().cast<RankedTensorType>();
585 // Ensure that region arguments span the index space.
586 if (!llvm::all_of(getBody().getArgumentTypes(),
587 [](Type ty) { return ty.isIndex(); }))
588 return emitError("all body arguments must be index");
589 if (getBody().getNumArguments() != resultTy.getRank())
590 return emitError("must have one body argument per input dimension");
591
592 // Ensure that the region yields an element of the right type.
593 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
594
595 if (yieldOp.getValue().getType() != resultTy.getElementType())
596 return emitOpError(
597 "body must be terminated with a `yield` operation of the tensor "
598 "element type");
599
600 return success();
601 }
602
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
607 build(b, result, resultTy, dynamicExtents);
608
609 // Build and populate body.
610 OpBuilder::InsertionGuard guard(b);
611 Region *bodyRegion = result.regions.front().get();
612 auto rank = resultTy.cast<RankedTensorType>().getRank();
613 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
614 SmallVector<Location, 2> argumentLocs(rank, result.location);
615 Block *bodyBlock =
616 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
617 bodyBuilder(b, result.location, bodyBlock->getArguments());
618 }
619
620 namespace {
621
/// Canonicalizes tensor.generate operations whose dynamic extent operands are
/// constants into the equivalent operation with those extents folded into the
/// result type. A tensor.cast back to the original type is inserted to keep
/// the resulting IR well-typed.
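///
/// Example (an illustrative sketch only):
/// ```mlir
/// %c5 = arith.constant 5 : index
/// %0 = tensor.generate %c5 {
/// ^bb0(%i: index, %j: index):
///   ...
///   tensor.yield %v : f32
/// } : tensor<4x?xf32>
/// ```
/// is rewritten into a tensor.generate of type tensor<4x5xf32> followed by a
/// tensor.cast back to tensor<4x?xf32>.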
626 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
627 using OpRewritePattern<GenerateOp>::OpRewritePattern;
628
  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
631 auto resultType =
632 tensorFromElements.getResult().getType().cast<RankedTensorType>();
633
634 if (resultType.hasStaticShape())
635 return failure();
636
637 SmallVector<Value, 4> newOperands;
638 SmallVector<int64_t, 4> newShape;
639 auto operandsIt = tensorFromElements.getDynamicExtents().begin();
640
641 for (int64_t dim : resultType.getShape()) {
642 if (!ShapedType::isDynamic(dim)) {
643 newShape.push_back(dim);
644 continue;
645 }
646 APInt index;
647 if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
648 newShape.push_back(ShapedType::kDynamicSize);
649 newOperands.push_back(*operandsIt++);
650 continue;
651 }
652 newShape.push_back(index.getSExtValue());
653 operandsIt++;
654 }
655
656 if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
657 return failure();
658
659 auto loc = tensorFromElements.getLoc();
660 auto newOp = rewriter.create<GenerateOp>(
661 loc, RankedTensorType::get(newShape, resultType.getElementType()),
662 newOperands);
663 rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
664 newOp.getBody().begin());
665 rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
666 newOp);
667 return success();
668 }
669 };
670
671 /// Canonicalizes the pattern of the form
672 ///
673 /// %tensor = tensor.generate %x {
674 /// ^bb0(%arg0: index):
675 /// <computation>
/// tensor.yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
679 ///
680 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
681 /// tensor.generate operation has no side-effects.
682 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
683 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
684
  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
687 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
688 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
689 return failure();
690
691 BlockAndValueMapping mapping;
692 Block *body = &tensorFromElements.getBody().front();
693 mapping.map(body->getArguments(), extract.getIndices());
694 for (auto &op : body->without_terminator())
695 rewriter.clone(op, mapping);
696
697 auto yield = cast<YieldOp>(body->getTerminator());
698
699 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
700 return success();
701 }
702 };
703
704 /// Canonicalizes the pattern of the form
705 ///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
707 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
708 ///
709 /// to
710 ///
711 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
712 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
713 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
714
  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
717 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
718 if (!tensorCast)
719 return failure();
720
721 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
722 extract, tensorCast.getSource(), extract.getIndices());
723 return success();
724 }
725 };
726
727 } // namespace
728
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
731 // TODO: Move extract patterns to tensor::ExtractOp.
732 results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
733 StaticTensorGenerate>(context);
734 }
735
736 //===----------------------------------------------------------------------===//
737 // RankOp
738 //===----------------------------------------------------------------------===//
739
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
741 // Constant fold rank when the rank of the operand is known.
742 auto type = getOperand().getType();
743 auto shapedType = type.dyn_cast<ShapedType>();
744 if (shapedType && shapedType.hasRank())
745 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
746 return IntegerAttr();
747 }
748
749 //===----------------------------------------------------------------------===//
750 // ReshapeOp
751 //===----------------------------------------------------------------------===//
752
static int64_t getNumElements(ShapedType type) {
754 int64_t numElements = 1;
755 for (auto dim : type.getShape())
756 numElements *= dim;
757 return numElements;
758 }
759
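/// An illustrative example (sketch only) of an op accepted by the checks
/// below, where %shape holds the values [2, 8]:
///
///   %dst = tensor.reshape %src(%shape)
///       : (tensor<4x4xf32>, tensor<2xi32>) -> tensor<2x8xf32>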
LogicalResult ReshapeOp::verify() {
761 TensorType operandType = getSource().getType().cast<TensorType>();
762 TensorType resultType = getResult().getType().cast<TensorType>();
763
764 if (operandType.getElementType() != resultType.getElementType())
765 return emitOpError("element types of source and destination tensor "
766 "types should be the same");
767
768 int64_t shapeSize =
769 getShape().getType().cast<RankedTensorType>().getDimSize(0);
770 auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
771 auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
772
773 if (resultRankedType) {
774 if (operandRankedType && resultRankedType.hasStaticShape() &&
775 operandRankedType.hasStaticShape()) {
776 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
777 return emitOpError("source and destination tensor should have the "
778 "same number of elements");
779 }
780 if (ShapedType::isDynamic(shapeSize))
781 return emitOpError("cannot use shape operand with dynamic length to "
782 "reshape to statically-ranked tensor type");
783 if (shapeSize != resultRankedType.getRank())
784 return emitOpError(
785 "length of shape operand differs from the result's tensor rank");
786 }
787 return success();
788 }
789
790 //===----------------------------------------------------------------------===//
791 // Reassociative reshape ops
792 //===----------------------------------------------------------------------===//
793
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
795 return getSymbolLessAffineMaps(getReassociationExprs());
796 }
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
798 return convertReassociationIndicesToExprs(getContext(),
799 getReassociationIndices());
800 }
801
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
803 return getSymbolLessAffineMaps(getReassociationExprs());
804 }
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
806 return convertReassociationIndicesToExprs(getContext(),
807 getReassociationIndices());
808 }
809
810 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
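///
/// For example (an illustrative sketch only): collapsing tensor<4x?x8xf32>
/// with reassociation [[0, 1], [2]] yields tensor<?x8xf32>; the first group
/// contains a dynamic size, so its product is dynamic, while the second group
/// keeps the static size 8.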
static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,
                                  ArrayRef<AffineMap> reassociation) {
814 auto shape = type.getShape();
815 SmallVector<int64_t, 4> newShape;
816 newShape.reserve(reassociation.size());
817
818 // Use the fact that reassociation is valid to simplify the logic: only use
819 // each map's rank.
820 assert(isReassociationValid(reassociation) && "invalid reassociation");
821 unsigned currentDim = 0;
822 for (AffineMap m : reassociation) {
823 unsigned dim = m.getNumResults();
824 auto band = shape.slice(currentDim, dim);
825 int64_t size = 1;
826 if (llvm::is_contained(band, ShapedType::kDynamicSize))
827 size = ShapedType::kDynamicSize;
828 else
829 for (unsigned d = 0; d < dim; ++d)
830 size *= shape[currentDim + d];
831 newShape.push_back(size);
832 currentDim += dim;
833 }
834
835 return RankedTensorType::get(newShape, type.getElementType());
836 }
837
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
841 auto resultType = computeTensorReshapeCollapsedType(
842 src.getType().cast<RankedTensorType>(),
843 getSymbolLessAffineMaps(
844 convertReassociationIndicesToExprs(b.getContext(), reassociation)));
845 build(b, result, resultType, src, attrs);
846 result.addAttribute(getReassociationAttrStrName(),
847 getReassociationIndicesAttribute(b, reassociation));
848 }
849
850 // Checks if types are the same, but ignoring encoding on ranked tensors.
static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
852 if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
853 if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
854 return rtp1.getShape() == rtp2.getShape() &&
855 rtp1.getElementType() == rtp2.getElementType();
856 return false;
857 }
858 // Default implementation.
859 return tp1 == tp2;
860 }
861
862 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
863 TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
867 if (failed(
868 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
869 return failure();
870
871 auto maps = op.getReassociationMaps();
872 RankedTensorType expectedType =
873 computeTensorReshapeCollapsedType(expandedType, maps);
874 if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
875 return op.emitOpError("expected collapsed type to be ")
876 << expectedType << ", but got " << collapsedType;
877 return success();
878 }
879
LogicalResult ExpandShapeOp::verify() {
881 return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
882 }
883
LogicalResult CollapseShapeOp::verify() {
885 return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
886 }
887
888 namespace {
889 /// Reshape of a splat constant can be replaced with a constant of the result
890 /// type.
891 template <typename TensorReshapeOp>
892 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
893 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
896 DenseElementsAttr attr;
897 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
898 return failure();
899 if (!attr || !attr.isSplat())
900 return failure();
901 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
902 reshapeOp.getResultType(), attr.getRawData());
903 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
904 return success();
905 }
906 };
907
908 /// Reshape of a FromElements can be replaced with a FromElements of the result
909 /// type
910 template <typename TensorReshapeOp>
911 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
912 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
915 auto fromElements =
916 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
917 if (!fromElements)
918 return failure();
919
920 auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
921
922 if (!shapedTy.hasStaticShape())
923 return failure();
924
925 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
926 fromElements.getElements());
927 return success();
928 }
929 };
930
931 } // namespace
932
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
935 results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
936 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
937 FoldReshapeWithConstant<ExpandShapeOp>,
938 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
939 }
940
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
943 results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
944 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
945 FoldReshapeWithConstant<CollapseShapeOp>,
946 FoldReshapeWithFromElements<CollapseShapeOp>>(context);
947 }
948
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
950 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
951 }
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
953 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
954 }
955
956 //===----------------------------------------------------------------------===//
957 // ExtractSliceOp
958 //===----------------------------------------------------------------------===//
959
960 /// An extract_slice result type can be inferred, when it is not
961 /// rank-reduced, from the source type and the static representation of
962 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
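///
/// For example (an illustrative sketch only): with static sizes [4, 4, 4]
/// taken from a tensor<8x16x4xf32> source, the inferred (non-rank-reduced)
/// result type is tensor<4x4x4xf32>, regardless of the offsets and strides.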
RankedTensorType ExtractSliceOp::inferResultType(
    ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, in which case we complete with offset=0, sizes from the source
  // tensor type, and strides=1.
969 assert(static_cast<int64_t>(staticSizes.size()) ==
970 sourceShapedTensorType.getRank() &&
971 "unexpected staticSizes not equal to rank of source");
972 return RankedTensorType::get(staticSizes,
973 sourceShapedTensorType.getElementType());
974 }
975
RankedTensorType ExtractSliceOp::inferResultType(
    ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
979 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
980 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
981 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
982 ShapedType::kDynamicStrideOrOffset);
983 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
984 ShapedType::kDynamicSize);
985 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
986 ShapedType::kDynamicStrideOrOffset);
987 return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
988 staticSizes, staticStrides);
989 }
990
991 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
992 /// number of sizes), drop as many size 1 as needed to produce an inferred type
993 /// with the desired rank.
994 ///
995 /// Note that there may be multiple ways to compute this rank-reduced type:
996 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
997 ///
998 /// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
1003 // Type inferred in the absence of rank-reducing behavior.
1004 auto inferredType =
1005 inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1006 .cast<RankedTensorType>();
1007 int rankDiff = inferredType.getRank() - desiredResultRank;
1008 if (rankDiff > 0) {
1009 auto shape = inferredType.getShape();
1010 llvm::SmallBitVector dimsToProject =
1011 getPositionsOfShapeOne(rankDiff, shape);
1012 SmallVector<int64_t> projectedShape;
1013 // Best effort rank-reducing: drop 1s in order.
1014 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1015 if (!dimsToProject.test(pos))
1016 projectedShape.push_back(shape[pos]);
1017 inferredType =
1018 RankedTensorType::get(projectedShape, inferredType.getElementType());
1019 }
1020 return inferredType;
1021 }
1022
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
1027 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1028 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1029 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1030 ShapedType::kDynamicStrideOrOffset);
1031 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1032 ShapedType::kDynamicSize);
1033 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1034 ShapedType::kDynamicStrideOrOffset);
1035 return ExtractSliceOp::inferCanonicalRankReducedResultType(
1036 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1037 staticStrides);
1038 }
1039
1040 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1041 /// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
1048 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1049 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1050 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1051 ShapedType::kDynamicStrideOrOffset);
1052 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1053 ShapedType::kDynamicSize);
1054 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1055 ShapedType::kDynamicStrideOrOffset);
1056 auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
1057 // Structuring implementation this way avoids duplication between builders.
1058 if (!resultType) {
1059 resultType =
1060 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
1061 staticSizes, staticStrides)
1062 .cast<RankedTensorType>();
1063 }
1064 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1065 dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1066 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1067 result.addAttributes(attrs);
1068 }
1069
1070 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1071 /// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
1077 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1078 }
1079
1080 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
1081 /// type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1086 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1087 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1088 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1089 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1090 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1091 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1092 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1093 }
1094
1095 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1099 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1100 }
1101
1102 template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          OpTy op, Type expectedType) {
1105 auto memrefType = expectedType.cast<ShapedType>();
1106 switch (result) {
1107 case SliceVerificationResult::Success:
1108 return success();
1109 case SliceVerificationResult::RankTooLarge:
1110 return op.emitError("expected rank to be smaller or equal to ")
1111 << "the other rank. ";
1112 case SliceVerificationResult::SizeMismatch:
1113 return op.emitError("expected type to be ")
1114 << expectedType << " or a rank-reduced version. (size mismatch) ";
1115 case SliceVerificationResult::ElemTypeMismatch:
1116 return op.emitError("expected element type to be ")
1117 << memrefType.getElementType();
1118 default:
1119 llvm_unreachable("unexpected extract_slice op verification result");
1120 }
1121 }
1122
1123 /// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
1125 // Verify result type against inferred type.
1126 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
1127 getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1128 SliceVerificationResult result = isRankReducedType(expectedType, getType());
1129 return produceSliceErrorMsg(result, *this, expectedType);
1130 }
1131
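/// Return a bit vector marking which dimensions of this extract_slice are
/// dropped by rank reduction. For example (an illustrative sketch only):
/// extracting sizes [1, 6, 1] into a tensor<6x1xf32> result drops dimension 0
/// but keeps the trailing unit dimension, since it still appears in the
/// result type.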
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
1133 ArrayRef<int64_t> resultShape = getType().getShape();
1134 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1135 llvm::SmallBitVector droppedDims(mixedSizes.size());
1136 unsigned shapePos = 0;
1137 for (const auto &size : enumerate(mixedSizes)) {
1138 Optional<int64_t> sizeVal = getConstantIntValue(size.value());
    // The dimension is preserved if its slice size is not a static 1, or if
    // the next unmatched result dimension is itself a static 1 (in which case
    // the unit dimension is kept in the result rather than dropped).
1142 if (!sizeVal || *sizeVal != 1 ||
1143 (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
1144 shapePos++;
1145 continue;
1146 }
1147 droppedDims.set(size.index());
1148 }
1149 return droppedDims;
1150 }
1151
LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1154 reifiedReturnShapes.resize(1);
1155 reifiedReturnShapes[0].reserve(getType().getRank());
1156 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1157 llvm::SmallBitVector droppedDims = getDroppedDims();
1158 Location loc = getLoc();
1159 for (const auto &size : enumerate(mixedSizes)) {
1160 if (droppedDims.test(size.index()))
1161 continue;
1162 if (auto attr = size.value().dyn_cast<Attribute>()) {
1163 reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
1164 loc, attr.cast<IntegerAttr>().getInt()));
1165 continue;
1166 }
1167 reifiedReturnShapes[0].push_back(size.value().get<Value>());
1168 }
1169 return success();
1170 }
1171
1172 namespace {
1173 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
1176 ///
1177 /// Example:
1178 /// ```
1179 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1180 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1181 /// tensor<3x4xf32>
1182 /// ```
1183 /// is rewritten into:
1184 /// ```
1185 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///        tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
1187 /// ```
1188 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1189 public:
1190 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1191
  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
1194 // Any constant operand, just return to let the constant folder kick in.
1195 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1196 return matchPattern(operand, matchConstantIndex());
1197 }))
1198 return failure();
1199
1200 auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
1201 if (!castOp)
1202 return failure();
1203
1204 if (!canFoldIntoConsumerOp(castOp))
1205 return failure();
1206
1207 /// Deduce the type of the result to use for the canonicalized operation.
1208 RankedTensorType resultType =
1209 ExtractSliceOp::inferCanonicalRankReducedResultType(
1210 sliceOp.getType().getRank(), sliceOp.getSourceType(),
1211 sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
1212 sliceOp.getMixedStrides());
1213 Value newSlice = rewriter.create<ExtractSliceOp>(
1214 sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
1215 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
1216 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1217 rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
1218 newSlice);
1219 return success();
1220 }
1221 };
1222
/// Slice elements from `values` into `outValues`. For each dimension,
/// `counts` holds the number of flattened elements spanned by one step along
/// that dimension in the original values. The output values can be used to
/// construct a DenseElementsAttr.
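///
/// For example (an illustrative sketch only): slicing a 2x4 source with
/// offsets [0, 1], sizes [2, 2], and strides [1, 2] (so `counts` is [4, 1])
/// copies the elements at flattened positions 1, 3, 5, and 7, i.e. columns 1
/// and 3 of both rows.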
1226 template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
1231 assert(offsets.size() == sizes.size());
1232 assert(offsets.size() == strides.size());
1233 if (offsets.empty())
1234 return;
1235
1236 int64_t offset = offsets.front();
1237 int64_t size = sizes.front();
1238 int64_t stride = strides.front();
1239 if (offsets.size() == 1) {
1240 for (int64_t i = 0; i < size; ++i, offset += stride)
1241 outValues->push_back(*(values + offset));
1242
1243 return;
1244 }
1245
1246 for (int64_t i = 0; i < size; ++i, offset += stride) {
1247 auto begin = values + offset * counts.front();
1248 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
1249 offsets.drop_front(), sizes.drop_front(),
1250 strides.drop_front(), outValues);
1251 }
1252 }
1253
/// Fold a tensor.extract_slice of an arith.constant into a new arith.constant.
/// The folded operation might introduce more constant data; users can control
/// whether the fold applies via the control function.
1257 class ConstantOpExtractSliceFolder final
1258 : public OpRewritePattern<ExtractSliceOp> {
1259 public:
1260 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1261
  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
1264 : OpRewritePattern<ExtractSliceOp>(context),
1265 controlFn(std::move(controlFn)) {}
1266
  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
1269 DenseElementsAttr attr;
1270 if (!matchPattern(op.getSource(), m_Constant(&attr)))
1271 return failure();
1272
1273 // A constant splat is handled by fold().
1274 if (attr.isSplat())
1275 return failure();
1276
1277 // Dynamic result shape is not supported.
1278 auto sourceType = op.getSource().getType().cast<ShapedType>();
1279 auto resultType = op.getResult().getType().cast<ShapedType>();
1280 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
1281 return failure();
1282
1283 // Customized control over the folding.
1284 if (!controlFn(op))
1285 return failure();
1286
1287 int64_t count = sourceType.getNumElements();
1288 if (count == 0)
1289 return failure();
1290
1291 // Check if there are any dynamic parts, which are not supported.
1292 auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
1293 if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
1294 return failure();
1295 auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
1296 if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
1297 return failure();
1298 auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
1299 if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
1300 return failure();
1301
1302 // Compute the stride for each dimension.
1303 SmallVector<int64_t> counts;
1304 ArrayRef<int64_t> shape = sourceType.getShape();
1305 counts.reserve(shape.size());
1306 for (int64_t v : shape) {
1307 count = count / v;
1308 counts.push_back(count);
1309 }
1310
1311 // New attribute constructed by the sliced values.
1312 DenseElementsAttr newAttr;
1313
1314 if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
1315 SmallVector<APInt> outValues;
1316 outValues.reserve(sourceType.getNumElements());
1317 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
1318 elems.begin(), counts, offsets, sizes, strides, &outValues);
1319 newAttr = DenseElementsAttr::get(resultType, outValues);
1320 } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
1321 SmallVector<APFloat> outValues;
1322 outValues.reserve(sourceType.getNumElements());
1323 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
1324 elems.begin(), counts, offsets, sizes, strides, &outValues);
1325 newAttr = DenseElementsAttr::get(resultType, outValues);
1326 }
1327
1328 if (newAttr) {
1329 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
1330 return success();
1331 }
1332
1333 return failure();
1334 }
1335
1336 private:
1337 /// This additionally controls whether the fold happens or not. Users can
1338 /// plug in their own heuristics via this function.
1339 ControlConstantExtractSliceFusionFn controlFn;
1340 };
1341
1342 } // namespace
1343
1344 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
1345 RewritePatternSet &patterns,
1346 const ControlConstantExtractSliceFusionFn &controlFn) {
1347 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
1348 }
1349
1350 /// Return the canonical type of the result of an extract_slice op.
1351 struct SliceReturnTypeCanonicalizer {
1352 RankedTensorType operator()(ExtractSliceOp op,
1353 ArrayRef<OpFoldResult> mixedOffsets,
1354 ArrayRef<OpFoldResult> mixedSizes,
1355 ArrayRef<OpFoldResult> mixedStrides) {
1356 return ExtractSliceOp::inferCanonicalRankReducedResultType(
1357 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
1358 mixedStrides);
1359 }
1360 };
1361
1362 /// A canonicalizer wrapper to replace ExtractSliceOps.
1363 struct SliceCanonicalizer {
1364 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
1365 ExtractSliceOp newOp) {
1366 Value replacement = newOp.getResult();
1367 if (replacement.getType() != op.getType())
1368 replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
1369 replacement);
1370 rewriter.replaceOp(op, replacement);
1371 }
1372 };
1373
1374 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1375 MLIRContext *context) {
1376 results.add<
1377 OpWithOffsetSizesAndStridesConstantArgumentFolder<
1378 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
1379 ExtractSliceOpCastFolder>(context);
1380 }
1381
1382 /// Return success if `op` is an identity slice of `shapedType`.
1383 static LogicalResult
1384 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
1385 ShapedType shapedType) {
1386 OpBuilder b(op.getContext());
1387 for (OpFoldResult ofr : op.getMixedOffsets())
1388 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
1389 return failure();
1390 // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
1391 // is appropriate.
1392 auto shape = shapedType.getShape();
1393 for (auto it : llvm::zip(op.getMixedSizes(), shape))
1394 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
1395 return failure();
1396 for (OpFoldResult ofr : op.getMixedStrides())
1397 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
1398 return failure();
1399 return success();
1400 }
1401
1402 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice,
1403 /// we can return the InsertSliceOp's source directly.
1404 // TODO: This only checks the immediate producer; extend to go up the
1405 // insert/extract chain if the slices are disjoint.
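///
/// For example (types are illustrative; %src and %dest are assumed values),
/// %1 below folds to %src because both ops address the exact same slice:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 1] [2, 2] [1, 1]
///     : tensor<2x2xf32> into tensor<4x4xf32>
/// %1 = tensor.extract_slice %0[0, 1] [2, 2] [1, 1]
///     : tensor<4x4xf32> to tensor<2x2xf32>
/// ```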
1406 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
1407 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
1408
1409 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1410 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
1411 insertOp.isSameAs(extractOp, isSame))
1412 return insertOp.getSource();
1413
1414 return {};
1415 }
1416
1417 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
1418 if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
1419 auto resultType = getResult().getType().cast<ShapedType>();
1420 if (resultType.hasStaticShape())
1421 return splat.resizeSplat(resultType);
1422 }
1423 if (getSourceType() == getType() &&
1424 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1425 return this->getSource();
1426 if (Value slice = foldExtractAfterInsertSlice(*this))
1427 return slice;
1428
1429 return OpFoldResult();
1430 }
1431
1432 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
1433 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
1434 auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
1435 unsigned rank = rankedTensorType.getRank();
1436 auto shape = rankedTensorType.getShape();
1437 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1438 SmallVector<OpFoldResult> sizes;
1439 for (unsigned i = 0, e = rank; i < e; ++i) {
1440 OpFoldResult dim;
1441 if (rankedTensorType.isDynamicDim(i))
1442 dim = b.createOrFold<tensor::DimOp>(
1443 loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
1444 else
1445 dim = b.getIndexAttr(shape[i]);
1446 sizes.push_back(dim);
1447 }
1448 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1449 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
1450 offsets, sizes, strides);
1451 }
1452
1453 //===----------------------------------------------------------------------===//
1454 // InsertSliceOp
1455 //===----------------------------------------------------------------------===//
1456
1457 // Build an InsertSliceOp with mixed static and dynamic entries.
1458 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1459 Value dest, ArrayRef<OpFoldResult> offsets,
1460 ArrayRef<OpFoldResult> sizes,
1461 ArrayRef<OpFoldResult> strides,
1462 ArrayRef<NamedAttribute> attrs) {
1463 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1464 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1465 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1466 ShapedType::kDynamicStrideOrOffset);
1467 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1468 ShapedType::kDynamicSize);
1469 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1470 ShapedType::kDynamicStrideOrOffset);
1471 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1472 dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1473 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1474 result.addAttributes(attrs);
1475 }
1476
1477 // Build an InsertSliceOp with dynamic entries.
1478 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1479 Value dest, ValueRange offsets, ValueRange sizes,
1480 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1481 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1482 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1483 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1484 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1485 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1486 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1487 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1488 }
1489
1490 /// Rank-reducing type verification for both InsertSliceOp and
1491 /// ParallelInsertSliceOp.
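///
/// For instance, the following rank-reducing insert (shapes are illustrative)
/// verifies because the inferred slice type tensor<1x16x4xf32> reduces to the
/// source type tensor<16x4xf32> by dropping the leading unit dimension:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 0, 0] [1, 16, 4] [1, 1, 1]
///     : tensor<16x4xf32> into tensor<8x16x4xf32>
/// ```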
1492 static SliceVerificationResult
1493 verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
1494 ArrayAttr staticOffsets, ArrayAttr staticSizes,
1495 ArrayAttr staticStrides,
1496 ShapedType *expectedType = nullptr) {
1497 // insert_slice is the inverse of extract_slice, use the same type inference.
1498 RankedTensorType expected = ExtractSliceOp::inferResultType(
1499 dstType, extractFromI64ArrayAttr(staticOffsets),
1500 extractFromI64ArrayAttr(staticSizes),
1501 extractFromI64ArrayAttr(staticStrides));
1502 if (expectedType)
1503 *expectedType = expected;
1504 return isRankReducedType(expected, srcType);
1505 }
1506
1507 /// Verifier for InsertSliceOp.
1508 LogicalResult InsertSliceOp::verify() {
1509 ShapedType expectedType;
1510 SliceVerificationResult result =
1511 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
1512 getStaticSizes(), getStaticStrides(), &expectedType);
1513 return produceSliceErrorMsg(result, *this, expectedType);
1514 }
1515
1516 /// If we have two consecutive InsertSliceOps writing to the same slice, we
1517 /// can mutate the second InsertSliceOp's destination to the first one's.
1518 /// This works similarly when the second op is a ParallelInsertSliceOp.
1519 ///
1520 /// Example:
1521 ///
1522 /// ```mlir
1523 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
1524 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
1525 /// ```
1526 ///
1527 /// folds into:
1528 ///
1529 /// ```mlir
1530 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
1531 /// ```
1532 ///
1533 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1534 template <typename InsertOpTy>
1535 static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
1536 auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
1537
1538 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1539 if (!prevInsertOp ||
1540 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
1541 !prevInsertOp.isSameAs(insertOp, isSame))
1542 return failure();
1543
1544 insertOp.getDestMutable().assign(prevInsertOp.getDest());
1545 return success();
1546 }
1547
1548 /// Shared folding logic for InsertSliceOp and ParallelInsertSliceOp; the
1549 /// return type differs between the two, so the result is wrapped in a FailureOr.
1550 ///
1551 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
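///
/// For example (types are illustrative), an identity insertion such as
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<4x4xf32>
/// ```
///
/// folds to %src (for InsertSliceOp; ParallelInsertSliceOp has no result).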
1552 template <typename InsertOpTy>
1553 FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
1554 if (insertOp.getSourceType().hasStaticShape() &&
1555 insertOp.getDestType().hasStaticShape() &&
1556 insertOp.getSourceType() == insertOp.getDestType() &&
1557 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
1558 insertOp, insertOp.getDestType())))
1559 return static_cast<OpFoldResult>(insertOp.getSource());
1560 if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
1561 // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
1562 // return OpFoldResult().
1563 if (std::is_same<InsertOpTy, InsertSliceOp>::value)
1564 return static_cast<OpFoldResult>(insertOp->getResult(0));
1565 else
1566 return OpFoldResult();
1567 }
1568 return failure();
1569 }
1570
1571 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
1572 auto maybeOpFoldResult = foldInsertOp(*this, operands);
1573 return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
1574 }
1575
1576 LogicalResult InsertSliceOp::reifyResultShapes(
1577 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1578 reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1579 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1580 reifiedReturnShapes[0][dim] =
1581 builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
1582 }
1583 return success();
1584 }
1585
1586 namespace {
1587 /// Pattern to rewrite an insert_slice op with constant arguments.
1588 ///
1589 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1590 template <typename InsertOpTy>
1591 class InsertSliceOpConstantArgumentFolder final
1592 : public OpRewritePattern<InsertOpTy> {
1593 public:
1594 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1595
1596 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1597 PatternRewriter &rewriter) const override {
1598 // No constant operand, just return.
1599 if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1600 return matchPattern(operand, matchConstantIndex());
1601 }))
1602 return failure();
1603
1604 // At least one of offsets/sizes/strides is a new constant.
1605 // Form the new list of operands and constant attributes from the
1606 // existing.
1607 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1608 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1609 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1610 canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1611 canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1612 canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1613
1614 // Create the new op in canonical form.
1615 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
1616 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
1617 mixedOffsets, mixedSizes, mixedStrides);
1618 Value toInsert = insertSliceOp.getSource();
1619 if (sourceType != insertSliceOp.getSourceType()) {
1620 OpBuilder::InsertionGuard g(rewriter);
1621 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
1622 // that the insertion point is just before the ParallelCombiningOp in the
1623 // parallel case.
1624 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
1625 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
1626 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
1627 sourceType, toInsert);
1628 }
1629 rewriter.replaceOpWithNewOp<InsertOpTy>(
1630 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
1631 mixedSizes, mixedStrides);
1632 return success();
1633 }
1634 };
1635
1636 /// Fold tensor_casts with insert_slice operations. If the source or destination
1637 /// tensor is a tensor_cast that removes static type information, the cast is
1638 /// folded into the insert_slice operation. E.g.:
1639 ///
1640 /// ```mlir
1641 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1642 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1643 /// ```
1644 ///
1645 /// folds into:
1646 ///
1647 /// ```mlir
1648 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1649 /// ```
1650 ///
1651 /// Note: When folding a cast on the destination tensor, the result of the
1652 /// insert_slice operation is cast back to ensure that the type of the result
1653 /// does not change.
1654 ///
1655 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1656 template <typename InsertOpTy>
1657 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
1658 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1659
1660 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1661 PatternRewriter &rewriter) const override {
1662 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1663 return matchPattern(operand, matchConstantIndex());
1664 }))
1665 return failure();
1666
1667 auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1668 auto castOp = v.getDefiningOp<tensor::CastOp>();
1669 if (!castOp || !canFoldIntoConsumerOp(castOp))
1670 return llvm::None;
1671 return castOp.getSource();
1672 };
1673 Optional<Value> sourceCastSource =
1674 getSourceOfCastOp(insertSliceOp.getSource());
1675 Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
1676 if (!sourceCastSource && !destCastSource)
1677 return failure();
1678
1679 auto src =
1680 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
1681 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
1682 auto srcType = src.getType().template cast<ShapedType>();
1683 auto dstType = dst.getType().template cast<ShapedType>();
1684 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
1685 insertSliceOp.getStaticSizes(),
1686 insertSliceOp.getStaticStrides()) !=
1687 SliceVerificationResult::Success)
1688 return failure();
1689
1690 Operation *replacement = rewriter.create<InsertOpTy>(
1691 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
1692 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
1693
1694 // In the parallel case there is no result and so nothing to cast.
1695 bool isParallelInsert =
1696 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
1697 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
1698 replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
1699 insertSliceOp.getDestType(),
1700 replacement->getResult(0));
1701 }
1702 rewriter.replaceOp(insertSliceOp, replacement->getResults());
1703 return success();
1704 }
1705 };
1706
1707 /// If additional static type information can be deduced from an insert_slice's
1708 /// size operands, insert an explicit cast of the op's source operand. This
1709 /// enables other canonicalization patterns that are matching for tensor_cast
1710 /// ops such as `ForOpTensorCastFolder` in SCF.
1711 ///
1712 /// Example:
1713 ///
1714 /// ```mlir
1715 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1716 /// : tensor<?x?xf32> into ...
1717 /// ```
1718 ///
1719 /// folds into:
1720 ///
1721 /// ```mlir
1722 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1723 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1724 /// : tensor<64x64xf32> into ...
1725 /// ```
1726 ///
1727 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1728 template <typename InsertOpTy>
1729 struct InsertSliceOpSourceCastInserter final
1730 : public OpRewritePattern<InsertOpTy> {
1731 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1732
1733 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1734 PatternRewriter &rewriter) const override {
1735 RankedTensorType srcType = insertSliceOp.getSourceType();
1736 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
1737 return failure();
1738 SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1739 srcType.getShape().end());
1740 for (int64_t i = 0; i < srcType.getRank(); ++i) {
1741 if (Optional<int64_t> constInt =
1742 getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1743 newSrcShape[i] = *constInt;
1744 }
1745
1746 RankedTensorType newSrcType =
1747 RankedTensorType::get(newSrcShape, srcType.getElementType());
1748 if (srcType == newSrcType ||
1749 !preservesStaticInformation(srcType, newSrcType) ||
1750 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
1751 return failure();
1752
1753 // newSrcType is:
1754 // 1) Different from srcType.
1755 // 2) "More static" than srcType.
1756 // 3) Cast-compatible with srcType.
1757 // Insert the cast.
1758 OpBuilder::InsertionGuard g(rewriter);
1759 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
1760 // that the insertion point is just before the ParallelCombiningOp in the
1761 // parallel case.
1762 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
1763 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
1764 Value cast = rewriter.create<tensor::CastOp>(
1765 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
1766 rewriter.replaceOpWithNewOp<InsertOpTy>(
1767 insertSliceOp, cast, insertSliceOp.getDest(),
1768 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1769 insertSliceOp.getMixedStrides());
1771 return success();
1772 }
1773 };
1774 } // namespace
1775
1776 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1777 MLIRContext *context) {
1778 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
1779 InsertSliceOpCastFolder<InsertSliceOp>,
1780 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
1781 }
1782
1783 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
1784 Location loc,
1785 Value tensor,
1786 Value dest) {
1787 auto rankedTensorType = dest.getType().cast<RankedTensorType>();
1788 unsigned rank = rankedTensorType.getRank();
1789 auto shape = rankedTensorType.getShape();
1790 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1791 SmallVector<OpFoldResult> sizes;
1792 for (unsigned i = 0, e = rank; i < e; ++i) {
1793 OpFoldResult dim;
1794 if (rankedTensorType.isDynamicDim(i))
1795 dim = b.createOrFold<tensor::DimOp>(
1796 loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
1797 else
1798 dim = b.getIndexAttr(shape[i]);
1799 sizes.push_back(dim);
1800 }
1801 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1802 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
1803 sizes, strides);
1804 }
1805
1806 //===----------------------------------------------------------------------===//
1807 // PadOp
1808 //===----------------------------------------------------------------------===//
1809
1810 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
1811 // supports optional types.
1812 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
1813 Type typeToInfer, Type typeToInferFrom) {}
1814
1815 ParseResult parseInferType(OpAsmParser &parser,
1816 Optional<OpAsmParser::UnresolvedOperand> optOperand,
1817 Type &typeToInfer, Type typeToInferFrom) {
1818 if (optOperand)
1819 typeToInfer = typeToInferFrom;
1820 return success();
1821 }
1822
1823 LogicalResult PadOp::verify() {
1824 auto sourceType = getSource().getType().cast<RankedTensorType>();
1825 auto resultType = getResult().getType().cast<RankedTensorType>();
1826 auto expectedType = PadOp::inferResultType(
1827 sourceType, extractFromI64ArrayAttr(getStaticLow()),
1828 extractFromI64ArrayAttr(getStaticHigh()));
1829 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
1830 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
1831 continue;
1832 if (expectedType.isDynamicDim(i))
1833 continue;
1834 return emitError("specified type ")
1835 << resultType << " does not match the inferred type "
1836 << expectedType;
1837 }
1838
1839 return success();
1840 }
1841
1842 LogicalResult PadOp::verifyRegions() {
1843 auto &region = getRegion();
1844 unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
1845 Block &block = region.front();
1846 if (block.getNumArguments() != rank)
1847 return emitError("expected the block to have ") << rank << " arguments";
1848
1849 // Note: the number and type of yield values are checked in the YieldOp.
1850 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
1851 if (!en.value().isIndex())
1852 return emitOpError("expected block argument ")
1853 << (en.index() + 1) << " to be an index";
1854 }
1855
1856 // Ensure that the region yields an element of the right type.
1857 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
1858 if (yieldOp.getValue().getType() !=
1859 getType().cast<ShapedType>().getElementType())
1860 return emitOpError("expected yield type to match shape element type");
1861
1862 return success();
1863 }
1864
1865 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
1866 ArrayRef<int64_t> staticLow,
1867 ArrayRef<int64_t> staticHigh,
1868 ArrayRef<int64_t> resultShape) {
1869 unsigned rank = sourceType.getRank();
1870 assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
1871 assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
1872 assert((resultShape.empty() || resultShape.size() == rank) &&
1873 "unexpected resultShape size mismatch");
1874
1875 SmallVector<int64_t, 4> inferredShape;
1876 for (auto i : llvm::seq<unsigned>(0, rank)) {
1877 if (sourceType.isDynamicDim(i) ||
1878 staticLow[i] == ShapedType::kDynamicSize ||
1879 staticHigh[i] == ShapedType::kDynamicSize) {
1880 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
1881 : resultShape[i]);
1882 } else {
1883 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
1884 assert((resultShape.empty() || size == resultShape[i] ||
1885 resultShape[i] == ShapedType::kDynamicSize) &&
1886 "mismatch between inferred shape and result shape");
1887 inferredShape.push_back(size);
1888 }
1889 }
1890
1891 return RankedTensorType::get(inferredShape, sourceType.getElementType());
1892 }
1893
1894 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1895 ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
1896 ValueRange low, ValueRange high, bool nofold,
1897 ArrayRef<NamedAttribute> attrs) {
1898 auto sourceType = source.getType().cast<RankedTensorType>();
1899 auto resultType = inferResultType(sourceType, staticLow, staticHigh);
1900 build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
1901 b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
1902 result.addAttributes(attrs);
1903 }
1904
1905 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1906 ValueRange low, ValueRange high, bool nofold,
1907 ArrayRef<NamedAttribute> attrs) {
1908 auto sourceType = source.getType().cast<RankedTensorType>();
1909 unsigned rank = sourceType.getRank();
1910 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
1911 build(b, result, source, staticVector, staticVector, low, high, nofold,
1912 attrs);
1913 }
1914
1915 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
1916 Value source, ArrayRef<OpFoldResult> low,
1917 ArrayRef<OpFoldResult> high, bool nofold,
1918 ArrayRef<NamedAttribute> attrs) {
1919 assert(resultType.isa<RankedTensorType>());
1920 auto sourceType = source.getType().cast<RankedTensorType>();
1921 SmallVector<Value, 4> dynamicLow, dynamicHigh;
1922 SmallVector<int64_t, 4> staticLow, staticHigh;
1923 // staticLow and staticHigh have full information of the padding config.
1924 // This will grow staticLow and staticHigh with 1 value per dimension. If a
1925 // padding value is dynamic (i.e. not a constant), dynamicLow and dynamicHigh
1926 // will grow with 1 value as well.
1927 dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
1928 ShapedType::kDynamicSize);
1929 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
1930 ShapedType::kDynamicSize);
1931 if (!resultType) {
1932 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
1933 }
1934 build(b, result, resultType, source, dynamicLow, dynamicHigh,
1935 b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
1936 nofold ? b.getUnitAttr() : UnitAttr());
1937 result.addAttributes(attrs);
1938 }
1939
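/// Return a bit vector with one bit per source dimension; a bit is set if the
/// low or high padding of that dimension is not statically known to be zero.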
1940 llvm::SmallBitVector PadOp::getPaddedDims() {
1941 llvm::SmallBitVector paddedDims(getSourceType().getRank());
1942 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
1943 for (const auto &en : enumerate(paddingWidths))
1944 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
1945 paddedDims.set(en.index());
1946 };
1947 extractPaddedDims(getMixedLowPad());
1948 extractPaddedDims(getMixedHighPad());
1949 return paddedDims;
1950 }
1951
1952 namespace {
1953 // Folds tensor.pad when the padding is all static zeros and the nofold
1954 // attribute doesn't request otherwise.
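//
// For example (shapes are illustrative; %cst is an assumed f32 value), the
// following pad with all-zero static padding:
//
// ```mlir
// %0 = tensor.pad %arg0 low[0, 0] high[0, 0] {
// ^bb0(%i: index, %j: index):
//   tensor.yield %cst : f32
// } : tensor<?x16xf32> to tensor<8x16xf32>
// ```
//
// is replaced by:
//
// ```mlir
// %0 = tensor.cast %arg0 : tensor<?x16xf32> to tensor<8x16xf32>
// ```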
1955 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
1956 using OpRewritePattern<PadOp>::OpRewritePattern;
1957
1958 LogicalResult matchAndRewrite(PadOp padTensorOp,
1959 PatternRewriter &rewriter) const override {
1960 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1961 return failure();
1962 if (padTensorOp.getNofold())
1963 return failure();
1964 rewriter.replaceOpWithNewOp<tensor::CastOp>(
1965 padTensorOp, padTensorOp.getResult().getType(),
1966 padTensorOp.getSource());
1967 return success();
1968 }
1969 };
1970
1971 // Fold CastOp into PadOp when adding static information.
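//
// For example (shapes are illustrative; %cst is an assumed f32 padding value),
// a cast that erases static sizes on the pad source:
//
// ```mlir
// %0 = tensor.cast %arg0 : tensor<64x64xf32> to tensor<?x?xf32>
// %1 = tensor.pad %0 low[0, 0] high[8, 8] {
// ^bb0(%i: index, %j: index):
//   tensor.yield %cst : f32
// } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// is folded so the pad uses %arg0 directly, producing tensor<72x72xf32>, and a
// tensor.cast back to tensor<?x?xf32> preserves the original result type.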
1972 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
1973 using OpRewritePattern<PadOp>::OpRewritePattern;
1974
1975 LogicalResult matchAndRewrite(PadOp padTensorOp,
1976 PatternRewriter &rewriter) const override {
1977 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
1978 if (!tensor::canFoldIntoConsumerOp(castOp))
1979 return failure();
1980
1981 auto newResultType = PadOp::inferResultType(
1982 castOp.getSource().getType().cast<RankedTensorType>(),
1983 extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
1984 extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
1985 padTensorOp.getResultType().getShape());
1986
1987 if (newResultType == padTensorOp.getResultType()) {
1988 rewriter.updateRootInPlace(padTensorOp, [&]() {
1989 padTensorOp.getSourceMutable().assign(castOp.getSource());
1990 });
1991 } else {
1992 auto newOp = rewriter.create<PadOp>(
1993 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
1994 padTensorOp.getLow(), padTensorOp.getHigh(),
1995 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
1996 padTensorOp.getNofold());
1997 BlockAndValueMapping mapper;
1998 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
1999
2000 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2001 padTensorOp, padTensorOp.getResultType(), newOp);
2002 }
2003 return success();
2004 }
2005 };
2006
2007 // Fold a CastOp consuming the result of a PadOp back into the latter if the
2008 // cast adds static information.
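//
// For example (shapes are illustrative; %h0/%h1 are assumed index values, %cst
// an assumed f32 padding value, and the cast is the only user of %0):
//
// ```mlir
// %0 = tensor.pad %arg0 low[0, 0] high[%h0, %h1] {
// ^bb0(%i: index, %j: index):
//   tensor.yield %cst : f32
// } : tensor<8x8xf32> to tensor<?x?xf32>
// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<16x16xf32>
// ```
//
// becomes a single tensor.pad producing tensor<16x16xf32> directly.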
2009 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2010 using OpRewritePattern<PadOp>::OpRewritePattern;
2011
2012 LogicalResult matchAndRewrite(PadOp padTensorOp,
2013 PatternRewriter &rewriter) const override {
2014 if (!padTensorOp.getResult().hasOneUse())
2015 return failure();
2016 auto tensorCastOp =
2017 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2018 if (!tensorCastOp)
2019 return failure();
2020 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
2021 tensorCastOp.getDest().getType()))
2022 return failure();
2023
2024 auto replacementOp = rewriter.create<PadOp>(
2025 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2026 padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
2027 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2028 padTensorOp.getNofold());
2029 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2030
2031 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
2032 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
2033 return success();
2034 }
2035 };
2036
2037 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2038 /// different dimensions. The pattern applies if the following preconditions
2039 /// hold:
2040 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
2041 /// 2) the tensor::ExtractSliceOps have only unit-strides,
2042 /// 3) the tensor::PadOps perform only high-padding,
2043 /// 4) the tensor::PadOps have the same constant padding value,
2044 /// 5) the tensor::PadOps do not have common padding dimensions,
2045 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2046 /// zero-offset for every dimension.
2047 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
2048 /// padded source dimensions.
2049 ///
2050 /// Example:
2051 ///
2052 /// ```mlir
2053 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2054 /// : tensor<64x64xf32> to tensor<?x64xf32>
2055 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2056 /// } : tensor<?x64xf32> to tensor<8x64xf32>
2057 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2058 /// : tensor<8x64xf32> to tensor<8x?xf32>
2059 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2060 /// } : tensor<8x?xf32> to tensor<8x4xf32>
2061 /// ```
2062 ///
2063 /// folds into:
2064 ///
2065 /// ```mlir
2066 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2067 /// : tensor<64x64xf32> to tensor<?x?xf32>
2068 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2069 /// } : tensor<?x?xf32> to tensor<8x4xf32>
2070 /// ```
2071 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
2072 using OpRewritePattern<PadOp>::OpRewritePattern;
2073
2074 LogicalResult matchAndRewrite(PadOp padOp,
2075 PatternRewriter &rewriter) const override {
2076 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
2077 if (!innerSliceOp)
2078 return failure();
2079 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
2080 if (!outerPadOp || outerPadOp.getNofold())
2081 return failure();
2082 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
2083 if (!outerSliceOp)
2084 return failure();
2085
2086 // 1) Fail if the chain is rank-reducing.
2087 int64_t rank = padOp.getSourceType().getRank();
2088 if (outerSliceOp.getSourceType().getRank() != rank) {
2089 return rewriter.notifyMatchFailure(padOp,
2090 "cannot fold rank-reducing chain");
2091 }
2092
2093 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2094 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
2095 return rewriter.notifyMatchFailure(
2096 padOp, "cannot fold non-unit stride ExtractSliceOps");
2097 }
2098
2099 // 3) Fail if the tensor::PadOps have non-zero low padding.
2100 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
2101 return rewriter.notifyMatchFailure(padOp,
2102 "cannot fold PadOps with low padding");
2103 }
2104
2105 // 4) Fail if the tensor::PadOps padding values do not match.
2106 Attribute innerAttr, outerAttr;
2107 Value innerValue = padOp.getConstantPaddingValue();
2108 Value outerValue = outerPadOp.getConstantPaddingValue();
2109 if (!innerValue || !outerValue ||
2110 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
2111 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
2112 innerAttr != outerAttr) {
2113 return rewriter.notifyMatchFailure(
2114 padOp, "cannot fold PadOps with different padding values");
2115 }
2116
2117 // 5) Fail if a dimension is padded by both tensor::PadOps.
2118 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
2119 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
2120 if (innerDims.anyCommon(outerDims)) {
2121 return rewriter.notifyMatchFailure(
2122 padOp, "cannot fold PadOps with common padding dimensions");
2123 }
2124
2125 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2126 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2127 // for every dimension, and use the offset of the other pair. Fail if no
2128 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2129 // exists.
2130 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
2131 for (auto &en : enumerate(newOffsets)) {
2132 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
2133 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
2134 if (!innerDims.test(en.index()) &&
2135 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
2136 en.value() = outerOffset;
2137 continue;
2138 }
2139 if (!outerDims.test(en.index()) &&
2140 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
2141 en.value() = innerOffset;
2142 continue;
2143 }
2144 return rewriter.notifyMatchFailure(
2145 padOp, "cannot find zero-offset and zero-padding pair");
2146 }
2147
2148 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
2149 // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
2150 // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
2151 // does not match the size of the padded dimension. Otherwise, take the size
2152 // of the inner tensor::ExtractSliceOp.
2153 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
2154 for (auto &en : enumerate(newSizes)) {
2155 if (!outerDims.test(en.index()))
2156 continue;
2157 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
2158 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
2159 assert(!ShapedType::isDynamic(sourceSize) &&
2160 "expected padded dimension to have a static size");
2161 if (getConstantIntValue(sliceSize) != sourceSize) {
2162 return rewriter.notifyMatchFailure(
2163 padOp, "cannot fold since the inner ExtractSliceOp size does not "
2164 "match the size of the outer padding");
2165 }
2166 en.value() = outerSliceOp.getMixedSizes()[en.index()];
2167 }
2168
2169 // Combine the high paddings of the two tensor::PadOps.
2170 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
2171 for (auto &en : enumerate(newHighPad)) {
2172 if (innerDims.test(en.index()))
2173 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
2174 if (outerDims.test(en.index()))
2175 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
2176 }
2177
2178 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
2179 // two paddings in one step.
2180 auto newSliceOp = rewriter.create<ExtractSliceOp>(
2181 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
2182 innerSliceOp.getMixedStrides());
2183 auto newPadOp = rewriter.create<PadOp>(
2184 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
2185 padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
2186 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
2187 newPadOp.getRegion().begin());
2188 rewriter.replaceOp(padOp, newPadOp.getResult());
2189 return success();
2190 }
2191 };
2192
2193 } // namespace
2194
2195 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2196 MLIRContext *context) {
2197 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2198 FoldOrthogonalPaddings>(context);
2199 }
2200
2201 /// Return the padding value of the PadOp if it is constant. In this context,
2202 /// "constant" means an actual constant or "defined outside of the block".
2203 ///
2204 /// Values are considered constant in three cases:
2205 /// - A ConstantLike value.
2206 /// - A basic block argument from a different block.
2207 /// - A value defined outside of the block.
2208 ///
2209 /// If the padding value is not constant, an empty Value is returned.
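///
/// For example (shapes are illustrative), the padding value below is %cst:
///
/// ```mlir
/// %cst = arith.constant 0.000000e+00 : f32
/// %0 = tensor.pad %arg0 low[0, 0] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<4x4xf32> to tensor<5x5xf32>
/// ```
///
/// Had the yielded value been computed inside the pad region, an empty Value
/// would be returned instead.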
2210 Value PadOp::getConstantPaddingValue() {
2211 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
2212 if (!yieldOp)
2213 return {};
2214 Value padValue = yieldOp.getValue();
2215 // Check if yield value is a constant.
2216 if (matchPattern(padValue, m_Constant()))
2217 return padValue;
2218 // Check if yield value is defined inside the PadOp block.
2219 if (padValue.getParentBlock() == &getRegion().front())
2220 return {};
2221 // Else: Yield value defined outside of the PadOp block.
2222 return padValue;
2223 }
2224
2225 OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
2226 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
2227 !getNofold())
2228 return getSource();
2229 return {};
2230 }
2231
2232 //===----------------------------------------------------------------------===//
2233 // ParallelInsertSliceOp
2234 //===----------------------------------------------------------------------===//
2235
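/// Return the result of the enclosing ParallelCombiningOpInterface parent that
/// is tied to this op, i.e. the parent result whose position matches this op's
/// position among the parent's yielding ops.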
2236 OpResult ParallelInsertSliceOp::getTiedOpResult() {
2237 ParallelCombiningOpInterface parallelCombiningParent =
2238 getParallelCombiningParent();
2239 for (const auto &it :
2240 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
2241 Operation &nextOp = it.value();
2242 if (&nextOp == getOperation())
2243 return parallelCombiningParent.getParentResult(it.index());
2244 }
2245 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
2246 }
2247
2248 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
2249 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
2250 Value source, Value dest,
2251 ArrayRef<OpFoldResult> offsets,
2252 ArrayRef<OpFoldResult> sizes,
2253 ArrayRef<OpFoldResult> strides,
2254 ArrayRef<NamedAttribute> attrs) {
2255 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2256 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2257 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2258 ShapedType::kDynamicStrideOrOffset);
2259 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2260 ShapedType::kDynamicSize);
2261 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2262 ShapedType::kDynamicStrideOrOffset);
2263 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
2264 dynamicStrides, b.getI64ArrayAttr(staticOffsets),
2265 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
2266 result.addAttributes(attrs);
2267 }
2268
2269 // Build a ParallelInsertSliceOp with dynamic entries.
2270 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
2271 Value source, Value dest, ValueRange offsets,
2272 ValueRange sizes, ValueRange strides,
2273 ArrayRef<NamedAttribute> attrs) {
2274 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2275 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2276 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2277 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2278 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2279 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2280 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2281 }
2282
2283 LogicalResult ParallelInsertSliceOp::verify() {
2284 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
2285 return this->emitError("expected ParallelCombiningOpInterface parent, got:")
2286 << *(getOperation()->getParentOp());
2287
2288 ShapedType expectedType;
2289 SliceVerificationResult result =
2290 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
2291 getStaticSizes(), getStaticStrides(), &expectedType);
2292 return produceSliceErrorMsg(result, *this, expectedType);
2293 }
2294
2295 namespace {
2296 /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
2297 class ParallelInsertSliceOpConstantArgumentFolder final
2298 : public OpRewritePattern<ParallelInsertSliceOp> {
2299 public:
2300 using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
2301
2302 LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
2303 PatternRewriter &rewriter) const override {
2304 // No constant operand, just return.
2305 if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
2306 return matchPattern(operand, matchConstantIndex());
2307 }))
2308 return failure();
2309
2310 // At least one of offsets/sizes/strides is a new constant.
2311 // Form the new list of operands and constant attributes from the
2312 // existing.
2313 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2314 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2315 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2316 canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
2317 canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
2318 canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
2319
2320 // Create the new op in canonical form.
2321 auto sourceType =
2322 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
2323 insertSliceOp.getSourceType().getRank(),
2324 insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
2325 mixedStrides);
2326 Value toInsert = insertSliceOp.getSource();
2327 if (sourceType != insertSliceOp.getSourceType()) {
2328 OpBuilder::InsertionGuard g(rewriter);
2329 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2330 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2331 sourceType, toInsert);
2332 }
2333 rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
2334 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2335 mixedSizes, mixedStrides);
2336 return success();
2337 }
2338 };
2339 } // namespace
2340
2341 LogicalResult
2342 ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
2343 SmallVectorImpl<OpFoldResult> &results) {
2344 return foldInsertOp(*this, operands);
2345 }
2346
2347 void ParallelInsertSliceOp::getCanonicalizationPatterns(
2348 RewritePatternSet &results, MLIRContext *context) {
2349 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
2350 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
2351 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
2352 }
2353
2354 //===----------------------------------------------------------------------===//
2355 // SplatOp
2356 //===----------------------------------------------------------------------===//
2357
2358 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2359 auto constOperand = operands.front();
2360 if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
2361 return {};
2362
2363 // SplatElementsAttr::get treats a single value for the second arg as a splat.
2364 return SplatElementsAttr::get(getType(), {constOperand});
2365 }
2366
2367 //===----------------------------------------------------------------------===//
2368 // TableGen'd op method definitions
2369 //===----------------------------------------------------------------------===//
2370
2371 #define GET_OP_CLASSES
2372 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
2373