17d0426ddSRiver Riddle //===- ComposeSubView.cpp - Combining composed subview ops ----------------===// 27d0426ddSRiver Riddle // 37d0426ddSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47d0426ddSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 57d0426ddSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67d0426ddSRiver Riddle // 77d0426ddSRiver Riddle //===----------------------------------------------------------------------===// 87d0426ddSRiver Riddle // 97d0426ddSRiver Riddle // This file contains patterns for combining composed subview ops (i.e. subview 107d0426ddSRiver Riddle // of a subview becomes a single subview). 117d0426ddSRiver Riddle // 127d0426ddSRiver Riddle //===----------------------------------------------------------------------===// 137d0426ddSRiver Riddle 147d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h" 157d0426ddSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 167d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 177d0426ddSRiver Riddle #include "mlir/IR/BuiltinAttributes.h" 187d0426ddSRiver Riddle #include "mlir/IR/OpDefinition.h" 197d0426ddSRiver Riddle #include "mlir/IR/PatternMatch.h" 207d0426ddSRiver Riddle #include "mlir/Transforms/DialectConversion.h" 217d0426ddSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 227d0426ddSRiver Riddle 237d0426ddSRiver Riddle using namespace mlir; 247d0426ddSRiver Riddle 257d0426ddSRiver Riddle namespace { 267d0426ddSRiver Riddle 277d0426ddSRiver Riddle // Replaces a subview of a subview with a single subview. Only supports subview 287d0426ddSRiver Riddle // ops with static sizes and static strides of 1 (both static and dynamic 297d0426ddSRiver Riddle // offsets are supported). 307d0426ddSRiver Riddle struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { 317d0426ddSRiver Riddle using OpRewritePattern::OpRewritePattern; 327d0426ddSRiver Riddle 337d0426ddSRiver Riddle LogicalResult matchAndRewrite(memref::SubViewOp op, 347d0426ddSRiver Riddle PatternRewriter &rewriter) const override { 357d0426ddSRiver Riddle // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that 367d0426ddSRiver Riddle // produces the input of the op we're rewriting (for 'SubViewOp' the input 377d0426ddSRiver Riddle // is called the "source" value). We can only combine them if both 'op' and 387d0426ddSRiver Riddle // 'sourceOp' are 'SubViewOp'. 397d0426ddSRiver Riddle auto sourceOp = op.source().getDefiningOp<memref::SubViewOp>(); 407d0426ddSRiver Riddle if (!sourceOp) 417d0426ddSRiver Riddle return failure(); 427d0426ddSRiver Riddle 437d0426ddSRiver Riddle // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the 447d0426ddSRiver Riddle // output memref that are statically known to be equal to 1. We do not 457d0426ddSRiver Riddle // allow 'sourceOp' to be a rank-reducing subview because then our two 467d0426ddSRiver Riddle // 'SubViewOp's would have different numbers of offset/size/stride 477d0426ddSRiver Riddle // parameters (just difficult to deal with, not impossible if we end up 487d0426ddSRiver Riddle // needing it). 497d0426ddSRiver Riddle if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) { 507d0426ddSRiver Riddle return failure(); 517d0426ddSRiver Riddle } 527d0426ddSRiver Riddle 537d0426ddSRiver Riddle // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'. 547d0426ddSRiver Riddle SmallVector<OpFoldResult> offsets, sizes, strides; 557d0426ddSRiver Riddle 567d0426ddSRiver Riddle // Because we only support input strides of 1, the output stride is also 577d0426ddSRiver Riddle // always 1. 587d0426ddSRiver Riddle if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) { 597d0426ddSRiver Riddle Attribute attr = valueOrAttr.dyn_cast<Attribute>(); 607d0426ddSRiver Riddle return attr && attr.cast<IntegerAttr>().getInt() == 1; 617d0426ddSRiver Riddle })) { 627d0426ddSRiver Riddle strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(), 637d0426ddSRiver Riddle rewriter.getI64IntegerAttr(1)); 647d0426ddSRiver Riddle } else { 657d0426ddSRiver Riddle return failure(); 667d0426ddSRiver Riddle } 677d0426ddSRiver Riddle 687d0426ddSRiver Riddle // The rules for calculating the new offsets and sizes are: 697d0426ddSRiver Riddle // * Multiple subview offsets for a given dimension compose additively. 707d0426ddSRiver Riddle // ("Offset by m" followed by "Offset by n" == "Offset by m + n") 717d0426ddSRiver Riddle // * Multiple sizes for a given dimension compose by taking the size of the 727d0426ddSRiver Riddle // final subview and ignoring the rest. ("Take m values" followed by "Take 737d0426ddSRiver Riddle // n values" == "Take n values") This size must also be the smallest one 747d0426ddSRiver Riddle // by definition (a subview needs to be the same size as or smaller than 757d0426ddSRiver Riddle // its source along each dimension; presumably subviews that are larger 767d0426ddSRiver Riddle // than their sources are disallowed by validation). 777d0426ddSRiver Riddle for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), 787d0426ddSRiver Riddle op.getMixedSizes())) { 797d0426ddSRiver Riddle auto opOffset = std::get<0>(it); 807d0426ddSRiver Riddle auto sourceOffset = std::get<1>(it); 817d0426ddSRiver Riddle auto opSize = std::get<2>(it); 827d0426ddSRiver Riddle 837d0426ddSRiver Riddle // We only support static sizes. 847d0426ddSRiver Riddle if (opSize.is<Value>()) { 857d0426ddSRiver Riddle return failure(); 867d0426ddSRiver Riddle } 877d0426ddSRiver Riddle 887d0426ddSRiver Riddle sizes.push_back(opSize); 897d0426ddSRiver Riddle Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(), 907d0426ddSRiver Riddle sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>(); 917d0426ddSRiver Riddle 927d0426ddSRiver Riddle if (opOffsetAttr && sourceOffsetAttr) { 937d0426ddSRiver Riddle // If both offsets are static we can simply calculate the combined 947d0426ddSRiver Riddle // offset statically. 957d0426ddSRiver Riddle offsets.push_back(rewriter.getI64IntegerAttr( 967d0426ddSRiver Riddle opOffsetAttr.cast<IntegerAttr>().getInt() + 977d0426ddSRiver Riddle sourceOffsetAttr.cast<IntegerAttr>().getInt())); 987d0426ddSRiver Riddle } else { 997d0426ddSRiver Riddle // When either offset is dynamic, we must emit an additional affine 1007d0426ddSRiver Riddle // transformation to add the two offsets together dynamically. 1017d0426ddSRiver Riddle AffineExpr expr = rewriter.getAffineConstantExpr(0); 1027d0426ddSRiver Riddle SmallVector<Value> affineApplyOperands; 1037d0426ddSRiver Riddle for (auto valueOrAttr : {opOffset, sourceOffset}) { 1047d0426ddSRiver Riddle if (auto attr = valueOrAttr.dyn_cast<Attribute>()) { 1057d0426ddSRiver Riddle expr = expr + attr.cast<IntegerAttr>().getInt(); 1067d0426ddSRiver Riddle } else { 1077d0426ddSRiver Riddle expr = 1087d0426ddSRiver Riddle expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()); 1097d0426ddSRiver Riddle affineApplyOperands.push_back(valueOrAttr.get<Value>()); 1107d0426ddSRiver Riddle } 1117d0426ddSRiver Riddle } 1127d0426ddSRiver Riddle 1137d0426ddSRiver Riddle AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr); 1147d0426ddSRiver Riddle Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map, 1157d0426ddSRiver Riddle affineApplyOperands); 1167d0426ddSRiver Riddle offsets.push_back(result); 1177d0426ddSRiver Riddle } 1187d0426ddSRiver Riddle } 1197d0426ddSRiver Riddle 1207d0426ddSRiver Riddle // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any 1217d0426ddSRiver Riddle // uses it can be removed by a (separate) dead code elimination pass. 1227d0426ddSRiver Riddle rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.source(), 1237d0426ddSRiver Riddle offsets, sizes, strides); 1247d0426ddSRiver Riddle return success(); 1257d0426ddSRiver Riddle } 1267d0426ddSRiver Riddle }; 1277d0426ddSRiver Riddle 1287d0426ddSRiver Riddle } // namespace 1297d0426ddSRiver Riddle 1307d0426ddSRiver Riddle void mlir::memref::populateComposeSubViewPatterns( 1319f85c198SRiver Riddle RewritePatternSet &patterns, MLIRContext *context) { 132*b4e0507cSTres Popp patterns.add<ComposeSubViewOpPattern>(context); 1337d0426ddSRiver Riddle } 134