1 //===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains patterns for combining composed subview ops (i.e. subview
10 // of a subview becomes a single subview).
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23 using namespace mlir;
24
25 namespace {
26
27 // Replaces a subview of a subview with a single subview. Only supports subview
28 // ops with static sizes and static strides of 1 (both static and dynamic
29 // offsets are supported).
30 struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
31 using OpRewritePattern::OpRewritePattern;
32
matchAndRewrite__anond24980220111::ComposeSubViewOpPattern33 LogicalResult matchAndRewrite(memref::SubViewOp op,
34 PatternRewriter &rewriter) const override {
35 // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
36 // produces the input of the op we're rewriting (for 'SubViewOp' the input
37 // is called the "source" value). We can only combine them if both 'op' and
38 // 'sourceOp' are 'SubViewOp'.
39 auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
40 if (!sourceOp)
41 return failure();
42
43 // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
44 // output memref that are statically known to be equal to 1. We do not
45 // allow 'sourceOp' to be a rank-reducing subview because then our two
46 // 'SubViewOp's would have different numbers of offset/size/stride
47 // parameters (just difficult to deal with, not impossible if we end up
48 // needing it).
49 if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
50 return failure();
51 }
52
53 // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
54 SmallVector<OpFoldResult> offsets, sizes, strides;
55
56 // Because we only support input strides of 1, the output stride is also
57 // always 1.
58 if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
59 Attribute attr = valueOrAttr.dyn_cast<Attribute>();
60 return attr && attr.cast<IntegerAttr>().getInt() == 1;
61 })) {
62 strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
63 rewriter.getI64IntegerAttr(1));
64 } else {
65 return failure();
66 }
67
68 // The rules for calculating the new offsets and sizes are:
69 // * Multiple subview offsets for a given dimension compose additively.
70 // ("Offset by m" followed by "Offset by n" == "Offset by m + n")
71 // * Multiple sizes for a given dimension compose by taking the size of the
72 // final subview and ignoring the rest. ("Take m values" followed by "Take
73 // n values" == "Take n values") This size must also be the smallest one
74 // by definition (a subview needs to be the same size as or smaller than
75 // its source along each dimension; presumably subviews that are larger
76 // than their sources are disallowed by validation).
77 for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
78 op.getMixedSizes())) {
79 auto opOffset = std::get<0>(it);
80 auto sourceOffset = std::get<1>(it);
81 auto opSize = std::get<2>(it);
82
83 // We only support static sizes.
84 if (opSize.is<Value>()) {
85 return failure();
86 }
87
88 sizes.push_back(opSize);
89 Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
90 sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
91
92 if (opOffsetAttr && sourceOffsetAttr) {
93 // If both offsets are static we can simply calculate the combined
94 // offset statically.
95 offsets.push_back(rewriter.getI64IntegerAttr(
96 opOffsetAttr.cast<IntegerAttr>().getInt() +
97 sourceOffsetAttr.cast<IntegerAttr>().getInt()));
98 } else {
99 // When either offset is dynamic, we must emit an additional affine
100 // transformation to add the two offsets together dynamically.
101 AffineExpr expr = rewriter.getAffineConstantExpr(0);
102 SmallVector<Value> affineApplyOperands;
103 for (auto valueOrAttr : {opOffset, sourceOffset}) {
104 if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
105 expr = expr + attr.cast<IntegerAttr>().getInt();
106 } else {
107 expr =
108 expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
109 affineApplyOperands.push_back(valueOrAttr.get<Value>());
110 }
111 }
112
113 AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
114 Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map,
115 affineApplyOperands);
116 offsets.push_back(result);
117 }
118 }
119
120 // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
121 // uses it can be removed by a (separate) dead code elimination pass.
122 rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.getSource(),
123 offsets, sizes, strides);
124 return success();
125 }
126 };
127
128 } // namespace
129
populateComposeSubViewPatterns(RewritePatternSet & patterns,MLIRContext * context)130 void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
131 MLIRContext *context) {
132 patterns.add<ComposeSubViewOpPattern>(context);
133 }
134