//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for combining composed subview ops (i.e. subview
// of a subview becomes a single subview).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {

// Replaces a subview of a subview with a single subview. Only supports subview
// ops with static sizes and static strides of 1 (both static and dynamic
// offsets are supported).
struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
    // produces the input of the op we're rewriting (for 'SubViewOp' the input
    // is called the "source" value). We can only combine them if both 'op' and
    // 'sourceOp' are 'SubViewOp'.
    auto sourceOp = op.source().getDefiningOp<memref::SubViewOp>();
    if (!sourceOp)
      return failure();

    // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
    // output memref that are statically known to be equal to 1. We do not
    // allow 'sourceOp' to be a rank-reducing subview because then our two
    // 'SubViewOp's would have different numbers of offset/size/stride
    // parameters (merely awkward to handle, not impossible, should we ever
    // need it).
    if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
      return failure();
    }

    // Offsets, sizes, and strides (as 'OpFoldResult's) for the combined
    // 'SubViewOp'.
    SmallVector<OpFoldResult> offsets, sizes, strides;

    // Because we only support input strides of 1, the output stride is also
    // always 1. Bail out if either subview has a stride that is not a static
    // 1.
    auto hasAllOneStrides = [](ArrayRef<OpFoldResult> strideValues) {
      return llvm::all_of(strideValues, [](OpFoldResult valueOrAttr) {
        Attribute attr = valueOrAttr.dyn_cast<Attribute>();
        return attr && attr.cast<IntegerAttr>().getInt() == 1;
      });
    };
    if (hasAllOneStrides(op.getMixedStrides()) &&
        hasAllOneStrides(sourceOp.getMixedStrides())) {
      strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
                                          rewriter.getI64IntegerAttr(1));
    } else {
      return failure();
    }

    // The rules for calculating the new offsets and sizes are:
    // * Multiple subview offsets for a given dimension compose additively.
    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
    // * Multiple sizes for a given dimension compose by taking the size of the
    //   final subview and ignoring the rest. ("Take m values" followed by
    //   "Take n values" == "Take n values") This size is also the smallest one
    //   by definition (a subview needs to be the same size as or smaller than
    //   its source along each dimension; subviews larger than their sources
    //   are presumably rejected by verification).
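    //
    // As an illustration (a hypothetical example with types elided for
    // brevity, not a test case from this file), composing
    //
    //   %0 = memref.subview %src[2, 4] [16, 16] [1, 1] : ...
    //   %1 = memref.subview %0[3, 5] [8, 8] [1, 1] : ...
    //
    // yields offsets [2 + 3, 4 + 5] = [5, 9] and sizes [8, 8], i.e. the
    // equivalent of:
    //
    //   %1 = memref.subview %src[5, 9] [8, 8] [1, 1] : ...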
    for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
                             op.getMixedSizes())) {
      auto opOffset = std::get<0>(it);
      auto sourceOffset = std::get<1>(it);
      auto opSize = std::get<2>(it);

      // We only support static sizes.
      if (opSize.is<Value>()) {
        return failure();
      }

      sizes.push_back(opSize);
      Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
                sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();

      if (opOffsetAttr && sourceOffsetAttr) {
        // If both offsets are static we can simply calculate the combined
        // offset statically.
        offsets.push_back(rewriter.getI64IntegerAttr(
            opOffsetAttr.cast<IntegerAttr>().getInt() +
            sourceOffsetAttr.cast<IntegerAttr>().getInt()));
      } else {
        // When either offset is dynamic, we must emit an additional affine
        // transformation to add the two offsets together dynamically.
        AffineExpr expr = rewriter.getAffineConstantExpr(0);
        SmallVector<Value> affineApplyOperands;
        for (auto valueOrAttr : {opOffset, sourceOffset}) {
          if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
            expr = expr + attr.cast<IntegerAttr>().getInt();
          } else {
            expr =
                expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
            affineApplyOperands.push_back(valueOrAttr.get<Value>());
          }
        }

        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
        Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map,
                                                      affineApplyOperands);
        offsets.push_back(result);
      }
    }

    // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
    // uses it can be removed by a (separate) dead code elimination pass.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.source(),
                                                   offsets, sizes, strides);
    return success();
  }
};

} // namespace

void mlir::memref::populateComposeSubViewPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.insert<ComposeSubViewOpPattern>(context);
}
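
// A minimal sketch of how a caller might drive these patterns (assuming the
// caller holds some op 'funcOp' to transform; the variable name is
// hypothetical, and the greedy driver is just one possible choice):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   memref::populateComposeSubViewPatterns(patterns, funcOp.getContext());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));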