1 //===----------- MultiBuffering.cpp ---------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements multi buffering transformation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 16 #include "mlir/IR/Dominance.h" 17 #include "mlir/Interfaces/LoopLikeInterface.h" 18 19 using namespace mlir; 20 21 /// Return true if the op fully overwrite the given `buffer` value. 22 static bool overrideBuffer(Operation *op, Value buffer) { 23 auto copyOp = dyn_cast<memref::CopyOp>(op); 24 if (!copyOp) 25 return false; 26 return copyOp.getTarget() == buffer; 27 } 28 29 /// Replace the uses of `oldOp` with the given `val` and for subview uses 30 /// propagate the type change. Changing the memref type may require propagating 31 /// it through subview ops so we cannot just do a replaceAllUse but need to 32 /// propagate the type change and erase old subview ops. 33 static void replaceUsesAndPropagateType(Operation *oldOp, Value val, 34 OpBuilder &builder) { 35 SmallVector<Operation *> opToDelete; 36 SmallVector<OpOperand *> operandsToReplace; 37 for (OpOperand &use : oldOp->getUses()) { 38 auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner()); 39 if (!subviewUse) { 40 // Save the operand to and replace outside the loop to not invalidate the 41 // iterator. 42 operandsToReplace.push_back(&use); 43 continue; 44 } 45 builder.setInsertionPoint(subviewUse); 46 Type newType = memref::SubViewOp::inferRankReducedResultType( 47 subviewUse.getType().getShape(), val.getType().cast<MemRefType>(), 48 extractFromI64ArrayAttr(subviewUse.getStaticOffsets()), 49 extractFromI64ArrayAttr(subviewUse.getStaticSizes()), 50 extractFromI64ArrayAttr(subviewUse.getStaticStrides())); 51 Value newSubview = builder.create<memref::SubViewOp>( 52 subviewUse->getLoc(), newType.cast<MemRefType>(), val, 53 subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), 54 subviewUse.getMixedStrides()); 55 replaceUsesAndPropagateType(subviewUse, newSubview, builder); 56 opToDelete.push_back(use.getOwner()); 57 } 58 for (OpOperand *operand : operandsToReplace) 59 operand->set(val); 60 // Clean up old subview ops. 61 for (Operation *op : opToDelete) 62 op->erase(); 63 } 64 65 /// Helper to convert get a value from an OpFoldResult or create it at the 66 /// builder insert point. 67 static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder, 68 Location loc) { 69 Value value = res.dyn_cast<Value>(); 70 if (value) 71 return value; 72 return builder.create<arith::ConstantIndexOp>( 73 loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt()); 74 } 75 76 // Transformation to do multi-buffering/array expansion to remove dependencies 77 // on the temporary allocation between consecutive loop iterations. 78 // Returns success if the transformation happened and failure otherwise. 79 // This is not a pattern as it requires propagating the new memref type to its 80 // uses and requires updating subview ops. 81 LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp, 82 unsigned multiplier) { 83 DominanceInfo dom(allocOp->getParentOp()); 84 LoopLikeOpInterface candidateLoop; 85 for (Operation *user : allocOp->getUsers()) { 86 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>(); 87 if (!parentLoop) 88 return failure(); 89 /// Make sure there is no loop carried dependency on the allocation. 90 if (!overrideBuffer(user, allocOp.getResult())) 91 continue; 92 // If this user doesn't dominate all the other users keep looking. 93 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { 94 return !dom.dominates(user, otherUser); 95 })) 96 continue; 97 candidateLoop = parentLoop; 98 break; 99 } 100 if (!candidateLoop) 101 return failure(); 102 llvm::Optional<Value> inductionVar = candidateLoop.getSingleInductionVar(); 103 llvm::Optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound(); 104 llvm::Optional<OpFoldResult> singleStep = candidateLoop.getSingleStep(); 105 if (!inductionVar || !lowerBound || !singleStep) 106 return failure(); 107 OpBuilder builder(candidateLoop); 108 Value stepValue = 109 getOrCreateValue(*singleStep, builder, candidateLoop->getLoc()); 110 Value lowerBoundValue = 111 getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc()); 112 SmallVector<int64_t, 4> newShape(1, multiplier); 113 ArrayRef<int64_t> oldShape = allocOp.getType().getShape(); 114 newShape.append(oldShape.begin(), oldShape.end()); 115 auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(), 116 MemRefLayoutAttrInterface(), 117 allocOp.getType().getMemorySpace()); 118 builder.setInsertionPoint(allocOp); 119 Location loc = allocOp->getLoc(); 120 auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref); 121 builder.setInsertionPoint(&candidateLoop.getLoopBody().front(), 122 candidateLoop.getLoopBody().front().begin()); 123 AffineExpr induc = getAffineDimExpr(0, allocOp.getContext()); 124 AffineExpr init = getAffineDimExpr(1, allocOp.getContext()); 125 AffineExpr step = getAffineDimExpr(2, allocOp.getContext()); 126 AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier; 127 auto map = AffineMap::get(3, 0, expr); 128 std::array<Value, 3> operands = {*inductionVar, lowerBoundValue, stepValue}; 129 Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands); 130 SmallVector<OpFoldResult> offsets, sizes, strides; 131 offsets.push_back(bufferIndex); 132 offsets.append(oldShape.size(), builder.getIndexAttr(0)); 133 strides.assign(oldShape.size() + 1, builder.getIndexAttr(1)); 134 sizes.push_back(builder.getIndexAttr(1)); 135 for (int64_t size : oldShape) 136 sizes.push_back(builder.getIndexAttr(size)); 137 auto dstMemref = 138 memref::SubViewOp::inferRankReducedResultType( 139 allocOp.getType().getShape(), newMemref, offsets, sizes, strides) 140 .cast<MemRefType>(); 141 Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc, 142 offsets, sizes, strides); 143 replaceUsesAndPropagateType(allocOp, subview, builder); 144 allocOp.erase(); 145 return success(); 146 } 147