1 //===----------- MultiBuffering.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the multi-buffering transformation.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/Interfaces/LoopLikeInterface.h"
18
19 using namespace mlir;
20
21 /// Return true if the op fully overwrite the given `buffer` value.
overrideBuffer(Operation * op,Value buffer)22 static bool overrideBuffer(Operation *op, Value buffer) {
23 auto copyOp = dyn_cast<memref::CopyOp>(op);
24 if (!copyOp)
25 return false;
26 return copyOp.getTarget() == buffer;
27 }
28
29 /// Replace the uses of `oldOp` with the given `val` and for subview uses
30 /// propagate the type change. Changing the memref type may require propagating
31 /// it through subview ops so we cannot just do a replaceAllUse but need to
32 /// propagate the type change and erase old subview ops.
replaceUsesAndPropagateType(Operation * oldOp,Value val,OpBuilder & builder)33 static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
34 OpBuilder &builder) {
35 SmallVector<Operation *> opToDelete;
36 SmallVector<OpOperand *> operandsToReplace;
37 for (OpOperand &use : oldOp->getUses()) {
38 auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
39 if (!subviewUse) {
40 // Save the operand to and replace outside the loop to not invalidate the
41 // iterator.
42 operandsToReplace.push_back(&use);
43 continue;
44 }
45 builder.setInsertionPoint(subviewUse);
46 Type newType = memref::SubViewOp::inferRankReducedResultType(
47 subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
48 extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
49 extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
50 extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
51 Value newSubview = builder.create<memref::SubViewOp>(
52 subviewUse->getLoc(), newType.cast<MemRefType>(), val,
53 subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
54 subviewUse.getMixedStrides());
55 replaceUsesAndPropagateType(subviewUse, newSubview, builder);
56 opToDelete.push_back(use.getOwner());
57 }
58 for (OpOperand *operand : operandsToReplace)
59 operand->set(val);
60 // Clean up old subview ops.
61 for (Operation *op : opToDelete)
62 op->erase();
63 }
64
65 /// Helper to convert get a value from an OpFoldResult or create it at the
66 /// builder insert point.
getOrCreateValue(OpFoldResult res,OpBuilder & builder,Location loc)67 static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
68 Location loc) {
69 Value value = res.dyn_cast<Value>();
70 if (value)
71 return value;
72 return builder.create<arith::ConstantIndexOp>(
73 loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt());
74 }
75
76 // Transformation to do multi-buffering/array expansion to remove dependencies
77 // on the temporary allocation between consecutive loop iterations.
78 // Returns success if the transformation happened and failure otherwise.
79 // This is not a pattern as it requires propagating the new memref type to its
80 // uses and requires updating subview ops.
multiBuffer(memref::AllocOp allocOp,unsigned multiplier)81 LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
82 unsigned multiplier) {
83 DominanceInfo dom(allocOp->getParentOp());
84 LoopLikeOpInterface candidateLoop;
85 for (Operation *user : allocOp->getUsers()) {
86 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
87 if (!parentLoop)
88 return failure();
89 /// Make sure there is no loop carried dependency on the allocation.
90 if (!overrideBuffer(user, allocOp.getResult()))
91 continue;
92 // If this user doesn't dominate all the other users keep looking.
93 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
94 return !dom.dominates(user, otherUser);
95 }))
96 continue;
97 candidateLoop = parentLoop;
98 break;
99 }
100 if (!candidateLoop)
101 return failure();
102 llvm::Optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
103 llvm::Optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
104 llvm::Optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
105 if (!inductionVar || !lowerBound || !singleStep)
106 return failure();
107 OpBuilder builder(candidateLoop);
108 Value stepValue =
109 getOrCreateValue(*singleStep, builder, candidateLoop->getLoc());
110 Value lowerBoundValue =
111 getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc());
112 SmallVector<int64_t, 4> newShape(1, multiplier);
113 ArrayRef<int64_t> oldShape = allocOp.getType().getShape();
114 newShape.append(oldShape.begin(), oldShape.end());
115 auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(),
116 MemRefLayoutAttrInterface(),
117 allocOp.getType().getMemorySpace());
118 builder.setInsertionPoint(allocOp);
119 Location loc = allocOp->getLoc();
120 auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref);
121 builder.setInsertionPoint(&candidateLoop.getLoopBody().front(),
122 candidateLoop.getLoopBody().front().begin());
123 AffineExpr induc = getAffineDimExpr(0, allocOp.getContext());
124 AffineExpr init = getAffineDimExpr(1, allocOp.getContext());
125 AffineExpr step = getAffineDimExpr(2, allocOp.getContext());
126 AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier;
127 auto map = AffineMap::get(3, 0, expr);
128 std::array<Value, 3> operands = {*inductionVar, lowerBoundValue, stepValue};
129 Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands);
130 SmallVector<OpFoldResult> offsets, sizes, strides;
131 offsets.push_back(bufferIndex);
132 offsets.append(oldShape.size(), builder.getIndexAttr(0));
133 strides.assign(oldShape.size() + 1, builder.getIndexAttr(1));
134 sizes.push_back(builder.getIndexAttr(1));
135 for (int64_t size : oldShape)
136 sizes.push_back(builder.getIndexAttr(size));
137 auto dstMemref =
138 memref::SubViewOp::inferRankReducedResultType(
139 allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
140 .cast<MemRefType>();
141 Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
142 offsets, sizes, strides);
143 replaceUsesAndPropagateType(allocOp, subview, builder);
144 allocOp.erase();
145 return success();
146 }
147