1 //===----------- MultiBuffering.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the multi-buffering transformation.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/Interfaces/LoopLikeInterface.h"
18 
19 using namespace mlir;
20 
21 /// Return true if the op fully overwrite the given `buffer` value.
22 static bool overrideBuffer(Operation *op, Value buffer) {
23   auto copyOp = dyn_cast<memref::CopyOp>(op);
24   if (!copyOp)
25     return false;
26   return copyOp.getTarget() == buffer;
27 }
28 
29 /// Replace the uses of `oldOp` with the given `val` and for subview uses
30 /// propagate the type change. Changing the memref type may require propagating
31 /// it through subview ops so we cannot just do a replaceAllUse but need to
32 /// propagate the type change and erase old subview ops.
33 static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
34                                         OpBuilder &builder) {
35   SmallVector<Operation *> opToDelete;
36   SmallVector<OpOperand *> operandsToReplace;
37   for (OpOperand &use : oldOp->getUses()) {
38     auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
39     if (!subviewUse) {
40       // Save the operand to and replace outside the loop to not invalidate the
41       // iterator.
42       operandsToReplace.push_back(&use);
43       continue;
44     }
45     builder.setInsertionPoint(subviewUse);
46     Type newType = memref::SubViewOp::inferRankReducedResultType(
47         subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
48         extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
49         extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
50         extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
51     Value newSubview = builder.create<memref::SubViewOp>(
52         subviewUse->getLoc(), newType.cast<MemRefType>(), val,
53         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
54         subviewUse.getMixedStrides());
55     replaceUsesAndPropagateType(subviewUse, newSubview, builder);
56     opToDelete.push_back(use.getOwner());
57   }
58   for (OpOperand *operand : operandsToReplace)
59     operand->set(val);
60   // Clean up old subview ops.
61   for (Operation *op : opToDelete)
62     op->erase();
63 }
64 
65 /// Helper to convert get a value from an OpFoldResult or create it at the
66 /// builder insert point.
67 static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
68                               Location loc) {
69   Value value = res.dyn_cast<Value>();
70   if (value)
71     return value;
72   return builder.create<arith::ConstantIndexOp>(
73       loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt());
74 }
75 
76 // Transformation to do multi-buffering/array expansion to remove dependencies
77 // on the temporary allocation between consecutive loop iterations.
78 // Returns success if the transformation happened and failure otherwise.
79 // This is not a pattern as it requires propagating the new memref type to its
80 // uses and requires updating subview ops.
81 LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
82                                         unsigned multiplier) {
83   DominanceInfo dom(allocOp->getParentOp());
84   LoopLikeOpInterface candidateLoop;
85   for (Operation *user : allocOp->getUsers()) {
86     auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
87     if (!parentLoop)
88       return failure();
89     /// Make sure there is no loop carried dependency on the allocation.
90     if (!overrideBuffer(user, allocOp.getResult()))
91       continue;
92     // If this user doesn't dominate all the other users keep looking.
93     if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
94           return !dom.dominates(user, otherUser);
95         }))
96       continue;
97     candidateLoop = parentLoop;
98     break;
99   }
100   if (!candidateLoop)
101     return failure();
102   llvm::Optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
103   llvm::Optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
104   llvm::Optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
105   if (!inductionVar || !lowerBound || !singleStep)
106     return failure();
107   OpBuilder builder(candidateLoop);
108   Value stepValue =
109       getOrCreateValue(*singleStep, builder, candidateLoop->getLoc());
110   Value lowerBoundValue =
111       getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc());
112   SmallVector<int64_t, 4> newShape(1, multiplier);
113   ArrayRef<int64_t> oldShape = allocOp.getType().getShape();
114   newShape.append(oldShape.begin(), oldShape.end());
115   auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(),
116                                    MemRefLayoutAttrInterface(),
117                                    allocOp.getType().getMemorySpace());
118   builder.setInsertionPoint(allocOp);
119   Location loc = allocOp->getLoc();
120   auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref);
121   builder.setInsertionPoint(&candidateLoop.getLoopBody().front(),
122                             candidateLoop.getLoopBody().front().begin());
123   AffineExpr induc = getAffineDimExpr(0, allocOp.getContext());
124   AffineExpr init = getAffineDimExpr(1, allocOp.getContext());
125   AffineExpr step = getAffineDimExpr(2, allocOp.getContext());
126   AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier;
127   auto map = AffineMap::get(3, 0, expr);
128   std::array<Value, 3> operands = {*inductionVar, lowerBoundValue, stepValue};
129   Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands);
130   SmallVector<OpFoldResult> offsets, sizes, strides;
131   offsets.push_back(bufferIndex);
132   offsets.append(oldShape.size(), builder.getIndexAttr(0));
133   strides.assign(oldShape.size() + 1, builder.getIndexAttr(1));
134   sizes.push_back(builder.getIndexAttr(1));
135   for (int64_t size : oldShape)
136     sizes.push_back(builder.getIndexAttr(size));
137   auto dstMemref =
138       memref::SubViewOp::inferRankReducedResultType(
139           allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
140           .cast<MemRefType>();
141   Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
142                                                     offsets, sizes, strides);
143   replaceUsesAndPropagateType(allocOp, subview, builder);
144   allocOp.erase();
145   return success();
146 }
147