1b1357fe6SThomas Raoux //===----------- MultiBuffering.cpp ---------------------------------------===//
2b1357fe6SThomas Raoux //
3b1357fe6SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b1357fe6SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
5b1357fe6SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b1357fe6SThomas Raoux //
7b1357fe6SThomas Raoux //===----------------------------------------------------------------------===//
8b1357fe6SThomas Raoux //
9b1357fe6SThomas Raoux // This file implements multi buffering transformation.
10b1357fe6SThomas Raoux //
11b1357fe6SThomas Raoux //===----------------------------------------------------------------------===//
12b1357fe6SThomas Raoux 
13b1357fe6SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h"
14b1357fe6SThomas Raoux #include "mlir/Dialect/MemRef/IR/MemRef.h"
15b1357fe6SThomas Raoux #include "mlir/Dialect/MemRef/Transforms/Passes.h"
16b1357fe6SThomas Raoux #include "mlir/IR/Dominance.h"
17b1357fe6SThomas Raoux #include "mlir/Interfaces/LoopLikeInterface.h"
18b1357fe6SThomas Raoux 
19b1357fe6SThomas Raoux using namespace mlir;
20b1357fe6SThomas Raoux 
21b1357fe6SThomas Raoux /// Return true if the op fully overwrite the given `buffer` value.
overrideBuffer(Operation * op,Value buffer)22b1357fe6SThomas Raoux static bool overrideBuffer(Operation *op, Value buffer) {
23b1357fe6SThomas Raoux   auto copyOp = dyn_cast<memref::CopyOp>(op);
24b1357fe6SThomas Raoux   if (!copyOp)
25b1357fe6SThomas Raoux     return false;
26*136d746eSJacques Pienaar   return copyOp.getTarget() == buffer;
27b1357fe6SThomas Raoux }
28b1357fe6SThomas Raoux 
29b1357fe6SThomas Raoux /// Replace the uses of `oldOp` with the given `val` and for subview uses
30b1357fe6SThomas Raoux /// propagate the type change. Changing the memref type may require propagating
31b1357fe6SThomas Raoux /// it through subview ops so we cannot just do a replaceAllUse but need to
32b1357fe6SThomas Raoux /// propagate the type change and erase old subview ops.
replaceUsesAndPropagateType(Operation * oldOp,Value val,OpBuilder & builder)33b1357fe6SThomas Raoux static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
34b1357fe6SThomas Raoux                                         OpBuilder &builder) {
35b1357fe6SThomas Raoux   SmallVector<Operation *> opToDelete;
36b1357fe6SThomas Raoux   SmallVector<OpOperand *> operandsToReplace;
37b1357fe6SThomas Raoux   for (OpOperand &use : oldOp->getUses()) {
38b1357fe6SThomas Raoux     auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
39b1357fe6SThomas Raoux     if (!subviewUse) {
40b1357fe6SThomas Raoux       // Save the operand to and replace outside the loop to not invalidate the
41b1357fe6SThomas Raoux       // iterator.
42b1357fe6SThomas Raoux       operandsToReplace.push_back(&use);
43b1357fe6SThomas Raoux       continue;
44b1357fe6SThomas Raoux     }
45b1357fe6SThomas Raoux     builder.setInsertionPoint(subviewUse);
46b1357fe6SThomas Raoux     Type newType = memref::SubViewOp::inferRankReducedResultType(
476c3c5f80SMatthias Springer         subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
48*136d746eSJacques Pienaar         extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
49*136d746eSJacques Pienaar         extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
50*136d746eSJacques Pienaar         extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
51b1357fe6SThomas Raoux     Value newSubview = builder.create<memref::SubViewOp>(
52b1357fe6SThomas Raoux         subviewUse->getLoc(), newType.cast<MemRefType>(), val,
53b1357fe6SThomas Raoux         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
54b1357fe6SThomas Raoux         subviewUse.getMixedStrides());
55b1357fe6SThomas Raoux     replaceUsesAndPropagateType(subviewUse, newSubview, builder);
56b1357fe6SThomas Raoux     opToDelete.push_back(use.getOwner());
57b1357fe6SThomas Raoux   }
58b1357fe6SThomas Raoux   for (OpOperand *operand : operandsToReplace)
59b1357fe6SThomas Raoux     operand->set(val);
60b1357fe6SThomas Raoux   // Clean up old subview ops.
61b1357fe6SThomas Raoux   for (Operation *op : opToDelete)
62b1357fe6SThomas Raoux     op->erase();
63b1357fe6SThomas Raoux }
64b1357fe6SThomas Raoux 
65b1357fe6SThomas Raoux /// Helper to convert get a value from an OpFoldResult or create it at the
66b1357fe6SThomas Raoux /// builder insert point.
getOrCreateValue(OpFoldResult res,OpBuilder & builder,Location loc)67b1357fe6SThomas Raoux static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
68b1357fe6SThomas Raoux                               Location loc) {
69b1357fe6SThomas Raoux   Value value = res.dyn_cast<Value>();
70b1357fe6SThomas Raoux   if (value)
71b1357fe6SThomas Raoux     return value;
72b1357fe6SThomas Raoux   return builder.create<arith::ConstantIndexOp>(
73b1357fe6SThomas Raoux       loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt());
74b1357fe6SThomas Raoux }
75b1357fe6SThomas Raoux 
76b1357fe6SThomas Raoux // Transformation to do multi-buffering/array expansion to remove dependencies
77b1357fe6SThomas Raoux // on the temporary allocation between consecutive loop iterations.
78b1357fe6SThomas Raoux // Returns success if the transformation happened and failure otherwise.
79b1357fe6SThomas Raoux // This is not a pattern as it requires propagating the new memref type to its
80b1357fe6SThomas Raoux // uses and requires updating subview ops.
multiBuffer(memref::AllocOp allocOp,unsigned multiplier)81b1357fe6SThomas Raoux LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
82b1357fe6SThomas Raoux                                         unsigned multiplier) {
83b1357fe6SThomas Raoux   DominanceInfo dom(allocOp->getParentOp());
84b1357fe6SThomas Raoux   LoopLikeOpInterface candidateLoop;
85b1357fe6SThomas Raoux   for (Operation *user : allocOp->getUsers()) {
86b1357fe6SThomas Raoux     auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
87b1357fe6SThomas Raoux     if (!parentLoop)
88b1357fe6SThomas Raoux       return failure();
89b1357fe6SThomas Raoux     /// Make sure there is no loop carried dependency on the allocation.
90b1357fe6SThomas Raoux     if (!overrideBuffer(user, allocOp.getResult()))
91b1357fe6SThomas Raoux       continue;
92b1357fe6SThomas Raoux     // If this user doesn't dominate all the other users keep looking.
93b1357fe6SThomas Raoux     if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
94b1357fe6SThomas Raoux           return !dom.dominates(user, otherUser);
95b1357fe6SThomas Raoux         }))
96b1357fe6SThomas Raoux       continue;
97b1357fe6SThomas Raoux     candidateLoop = parentLoop;
98b1357fe6SThomas Raoux     break;
99b1357fe6SThomas Raoux   }
100b1357fe6SThomas Raoux   if (!candidateLoop)
101b1357fe6SThomas Raoux     return failure();
102b1357fe6SThomas Raoux   llvm::Optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
103b1357fe6SThomas Raoux   llvm::Optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
104b1357fe6SThomas Raoux   llvm::Optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
105b1357fe6SThomas Raoux   if (!inductionVar || !lowerBound || !singleStep)
106b1357fe6SThomas Raoux     return failure();
107b1357fe6SThomas Raoux   OpBuilder builder(candidateLoop);
108b1357fe6SThomas Raoux   Value stepValue =
109b1357fe6SThomas Raoux       getOrCreateValue(*singleStep, builder, candidateLoop->getLoc());
110b1357fe6SThomas Raoux   Value lowerBoundValue =
111b1357fe6SThomas Raoux       getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc());
112b1357fe6SThomas Raoux   SmallVector<int64_t, 4> newShape(1, multiplier);
113b1357fe6SThomas Raoux   ArrayRef<int64_t> oldShape = allocOp.getType().getShape();
114b1357fe6SThomas Raoux   newShape.append(oldShape.begin(), oldShape.end());
115b1357fe6SThomas Raoux   auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(),
116b1357fe6SThomas Raoux                                    MemRefLayoutAttrInterface(),
117b1357fe6SThomas Raoux                                    allocOp.getType().getMemorySpace());
118b1357fe6SThomas Raoux   builder.setInsertionPoint(allocOp);
119b1357fe6SThomas Raoux   Location loc = allocOp->getLoc();
120b1357fe6SThomas Raoux   auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref);
121b1357fe6SThomas Raoux   builder.setInsertionPoint(&candidateLoop.getLoopBody().front(),
122b1357fe6SThomas Raoux                             candidateLoop.getLoopBody().front().begin());
123b1357fe6SThomas Raoux   AffineExpr induc = getAffineDimExpr(0, allocOp.getContext());
124b1357fe6SThomas Raoux   AffineExpr init = getAffineDimExpr(1, allocOp.getContext());
125b1357fe6SThomas Raoux   AffineExpr step = getAffineDimExpr(2, allocOp.getContext());
126b1357fe6SThomas Raoux   AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier;
127b1357fe6SThomas Raoux   auto map = AffineMap::get(3, 0, expr);
128b1357fe6SThomas Raoux   std::array<Value, 3> operands = {*inductionVar, lowerBoundValue, stepValue};
129b1357fe6SThomas Raoux   Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands);
130b1357fe6SThomas Raoux   SmallVector<OpFoldResult> offsets, sizes, strides;
131b1357fe6SThomas Raoux   offsets.push_back(bufferIndex);
132b1357fe6SThomas Raoux   offsets.append(oldShape.size(), builder.getIndexAttr(0));
133b1357fe6SThomas Raoux   strides.assign(oldShape.size() + 1, builder.getIndexAttr(1));
134b1357fe6SThomas Raoux   sizes.push_back(builder.getIndexAttr(1));
135b1357fe6SThomas Raoux   for (int64_t size : oldShape)
136b1357fe6SThomas Raoux     sizes.push_back(builder.getIndexAttr(size));
137b1357fe6SThomas Raoux   auto dstMemref =
138b1357fe6SThomas Raoux       memref::SubViewOp::inferRankReducedResultType(
1396c3c5f80SMatthias Springer           allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
140b1357fe6SThomas Raoux           .cast<MemRefType>();
141b1357fe6SThomas Raoux   Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
142b1357fe6SThomas Raoux                                                     offsets, sizes, strides);
143b1357fe6SThomas Raoux   replaceUsesAndPropagateType(allocOp, subview, builder);
144b1357fe6SThomas Raoux   allocOp.erase();
145b1357fe6SThomas Raoux   return success();
146b1357fe6SThomas Raoux }
147