17d0426ddSRiver Riddle //===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
27d0426ddSRiver Riddle //
37d0426ddSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47d0426ddSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57d0426ddSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67d0426ddSRiver Riddle //
77d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
87d0426ddSRiver Riddle //
97d0426ddSRiver Riddle // This file contains patterns for combining composed subview ops (i.e. subview
107d0426ddSRiver Riddle // of a subview becomes a single subview).
117d0426ddSRiver Riddle //
127d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
137d0426ddSRiver Riddle 
147d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
157d0426ddSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
167d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
177d0426ddSRiver Riddle #include "mlir/IR/BuiltinAttributes.h"
187d0426ddSRiver Riddle #include "mlir/IR/OpDefinition.h"
197d0426ddSRiver Riddle #include "mlir/IR/PatternMatch.h"
207d0426ddSRiver Riddle #include "mlir/Transforms/DialectConversion.h"
217d0426ddSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
227d0426ddSRiver Riddle 
237d0426ddSRiver Riddle using namespace mlir;
247d0426ddSRiver Riddle 
257d0426ddSRiver Riddle namespace {
267d0426ddSRiver Riddle 
277d0426ddSRiver Riddle // Replaces a subview of a subview with a single subview. Only supports subview
287d0426ddSRiver Riddle // ops with static sizes and static strides of 1 (both static and dynamic
297d0426ddSRiver Riddle // offsets are supported).
307d0426ddSRiver Riddle struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
317d0426ddSRiver Riddle   using OpRewritePattern::OpRewritePattern;
327d0426ddSRiver Riddle 
matchAndRewrite__anond24980220111::ComposeSubViewOpPattern337d0426ddSRiver Riddle   LogicalResult matchAndRewrite(memref::SubViewOp op,
347d0426ddSRiver Riddle                                 PatternRewriter &rewriter) const override {
357d0426ddSRiver Riddle     // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
367d0426ddSRiver Riddle     // produces the input of the op we're rewriting (for 'SubViewOp' the input
377d0426ddSRiver Riddle     // is called the "source" value). We can only combine them if both 'op' and
387d0426ddSRiver Riddle     // 'sourceOp' are 'SubViewOp'.
39136d746eSJacques Pienaar     auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
407d0426ddSRiver Riddle     if (!sourceOp)
417d0426ddSRiver Riddle       return failure();
427d0426ddSRiver Riddle 
437d0426ddSRiver Riddle     // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
447d0426ddSRiver Riddle     // output memref that are statically known to be equal to 1. We do not
457d0426ddSRiver Riddle     // allow 'sourceOp' to be a rank-reducing subview because then our two
467d0426ddSRiver Riddle     // 'SubViewOp's would have different numbers of offset/size/stride
477d0426ddSRiver Riddle     // parameters (just difficult to deal with, not impossible if we end up
487d0426ddSRiver Riddle     // needing it).
497d0426ddSRiver Riddle     if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
507d0426ddSRiver Riddle       return failure();
517d0426ddSRiver Riddle     }
527d0426ddSRiver Riddle 
537d0426ddSRiver Riddle     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
547d0426ddSRiver Riddle     SmallVector<OpFoldResult> offsets, sizes, strides;
557d0426ddSRiver Riddle 
567d0426ddSRiver Riddle     // Because we only support input strides of 1, the output stride is also
577d0426ddSRiver Riddle     // always 1.
587d0426ddSRiver Riddle     if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
597d0426ddSRiver Riddle           Attribute attr = valueOrAttr.dyn_cast<Attribute>();
607d0426ddSRiver Riddle           return attr && attr.cast<IntegerAttr>().getInt() == 1;
617d0426ddSRiver Riddle         })) {
627d0426ddSRiver Riddle       strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
637d0426ddSRiver Riddle                                           rewriter.getI64IntegerAttr(1));
647d0426ddSRiver Riddle     } else {
657d0426ddSRiver Riddle       return failure();
667d0426ddSRiver Riddle     }
677d0426ddSRiver Riddle 
687d0426ddSRiver Riddle     // The rules for calculating the new offsets and sizes are:
697d0426ddSRiver Riddle     // * Multiple subview offsets for a given dimension compose additively.
707d0426ddSRiver Riddle     //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
717d0426ddSRiver Riddle     // * Multiple sizes for a given dimension compose by taking the size of the
727d0426ddSRiver Riddle     //   final subview and ignoring the rest. ("Take m values" followed by "Take
737d0426ddSRiver Riddle     //   n values" == "Take n values") This size must also be the smallest one
747d0426ddSRiver Riddle     //   by definition (a subview needs to be the same size as or smaller than
757d0426ddSRiver Riddle     //   its source along each dimension; presumably subviews that are larger
767d0426ddSRiver Riddle     //   than their sources are disallowed by validation).
777d0426ddSRiver Riddle     for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
787d0426ddSRiver Riddle                              op.getMixedSizes())) {
797d0426ddSRiver Riddle       auto opOffset = std::get<0>(it);
807d0426ddSRiver Riddle       auto sourceOffset = std::get<1>(it);
817d0426ddSRiver Riddle       auto opSize = std::get<2>(it);
827d0426ddSRiver Riddle 
837d0426ddSRiver Riddle       // We only support static sizes.
847d0426ddSRiver Riddle       if (opSize.is<Value>()) {
857d0426ddSRiver Riddle         return failure();
867d0426ddSRiver Riddle       }
877d0426ddSRiver Riddle 
887d0426ddSRiver Riddle       sizes.push_back(opSize);
897d0426ddSRiver Riddle       Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
907d0426ddSRiver Riddle                 sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
917d0426ddSRiver Riddle 
927d0426ddSRiver Riddle       if (opOffsetAttr && sourceOffsetAttr) {
937d0426ddSRiver Riddle         // If both offsets are static we can simply calculate the combined
947d0426ddSRiver Riddle         // offset statically.
957d0426ddSRiver Riddle         offsets.push_back(rewriter.getI64IntegerAttr(
967d0426ddSRiver Riddle             opOffsetAttr.cast<IntegerAttr>().getInt() +
977d0426ddSRiver Riddle             sourceOffsetAttr.cast<IntegerAttr>().getInt()));
987d0426ddSRiver Riddle       } else {
997d0426ddSRiver Riddle         // When either offset is dynamic, we must emit an additional affine
1007d0426ddSRiver Riddle         // transformation to add the two offsets together dynamically.
1017d0426ddSRiver Riddle         AffineExpr expr = rewriter.getAffineConstantExpr(0);
1027d0426ddSRiver Riddle         SmallVector<Value> affineApplyOperands;
1037d0426ddSRiver Riddle         for (auto valueOrAttr : {opOffset, sourceOffset}) {
1047d0426ddSRiver Riddle           if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
1057d0426ddSRiver Riddle             expr = expr + attr.cast<IntegerAttr>().getInt();
1067d0426ddSRiver Riddle           } else {
1077d0426ddSRiver Riddle             expr =
1087d0426ddSRiver Riddle                 expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
1097d0426ddSRiver Riddle             affineApplyOperands.push_back(valueOrAttr.get<Value>());
1107d0426ddSRiver Riddle           }
1117d0426ddSRiver Riddle         }
1127d0426ddSRiver Riddle 
1137d0426ddSRiver Riddle         AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
1147d0426ddSRiver Riddle         Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map,
1157d0426ddSRiver Riddle                                                       affineApplyOperands);
1167d0426ddSRiver Riddle         offsets.push_back(result);
1177d0426ddSRiver Riddle       }
1187d0426ddSRiver Riddle     }
1197d0426ddSRiver Riddle 
1207d0426ddSRiver Riddle     // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
1217d0426ddSRiver Riddle     // uses it can be removed by a (separate) dead code elimination pass.
122136d746eSJacques Pienaar     rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.getSource(),
1237d0426ddSRiver Riddle                                                    offsets, sizes, strides);
1247d0426ddSRiver Riddle     return success();
1257d0426ddSRiver Riddle   }
1267d0426ddSRiver Riddle };
1277d0426ddSRiver Riddle 
1287d0426ddSRiver Riddle } // namespace
1297d0426ddSRiver Riddle 
populateComposeSubViewPatterns(RewritePatternSet & patterns,MLIRContext * context)130*b7f93c28SJeff Niu void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
131*b7f93c28SJeff Niu                                                   MLIRContext *context) {
132b4e0507cSTres Popp   patterns.add<ComposeSubViewOpPattern>(context);
1337d0426ddSRiver Riddle }
134