//===----------- MultiBuffering.cpp ---------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements multi buffering transformation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/IR/Dominance.h" #include "mlir/Interfaces/LoopLikeInterface.h" using namespace mlir; /// Return true if the op fully overwrite the given `buffer` value. static bool overrideBuffer(Operation *op, Value buffer) { auto copyOp = dyn_cast(op); if (!copyOp) return false; return copyOp.getTarget() == buffer; } /// Replace the uses of `oldOp` with the given `val` and for subview uses /// propagate the type change. Changing the memref type may require propagating /// it through subview ops so we cannot just do a replaceAllUse but need to /// propagate the type change and erase old subview ops. static void replaceUsesAndPropagateType(Operation *oldOp, Value val, OpBuilder &builder) { SmallVector opToDelete; SmallVector operandsToReplace; for (OpOperand &use : oldOp->getUses()) { auto subviewUse = dyn_cast(use.getOwner()); if (!subviewUse) { // Save the operand to and replace outside the loop to not invalidate the // iterator. operandsToReplace.push_back(&use); continue; } builder.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( subviewUse.getType().getShape(), val.getType().cast(), extractFromI64ArrayAttr(subviewUse.getStaticOffsets()), extractFromI64ArrayAttr(subviewUse.getStaticSizes()), extractFromI64ArrayAttr(subviewUse.getStaticStrides())); Value newSubview = builder.create( subviewUse->getLoc(), newType.cast(), val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); replaceUsesAndPropagateType(subviewUse, newSubview, builder); opToDelete.push_back(use.getOwner()); } for (OpOperand *operand : operandsToReplace) operand->set(val); // Clean up old subview ops. for (Operation *op : opToDelete) op->erase(); } /// Helper to convert get a value from an OpFoldResult or create it at the /// builder insert point. static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder, Location loc) { Value value = res.dyn_cast(); if (value) return value; return builder.create( loc, res.dyn_cast().cast().getInt()); } // Transformation to do multi-buffering/array expansion to remove dependencies // on the temporary allocation between consecutive loop iterations. // Returns success if the transformation happened and failure otherwise. // This is not a pattern as it requires propagating the new memref type to its // uses and requires updating subview ops. LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier) { DominanceInfo dom(allocOp->getParentOp()); LoopLikeOpInterface candidateLoop; for (Operation *user : allocOp->getUsers()) { auto parentLoop = user->getParentOfType(); if (!parentLoop) return failure(); /// Make sure there is no loop carried dependency on the allocation. if (!overrideBuffer(user, allocOp.getResult())) continue; // If this user doesn't dominate all the other users keep looking. if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { return !dom.dominates(user, otherUser); })) continue; candidateLoop = parentLoop; break; } if (!candidateLoop) return failure(); llvm::Optional inductionVar = candidateLoop.getSingleInductionVar(); llvm::Optional lowerBound = candidateLoop.getSingleLowerBound(); llvm::Optional singleStep = candidateLoop.getSingleStep(); if (!inductionVar || !lowerBound || !singleStep) return failure(); OpBuilder builder(candidateLoop); Value stepValue = getOrCreateValue(*singleStep, builder, candidateLoop->getLoc()); Value lowerBoundValue = getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc()); SmallVector newShape(1, multiplier); ArrayRef oldShape = allocOp.getType().getShape(); newShape.append(oldShape.begin(), oldShape.end()); auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(), MemRefLayoutAttrInterface(), allocOp.getType().getMemorySpace()); builder.setInsertionPoint(allocOp); Location loc = allocOp->getLoc(); auto newAlloc = builder.create(loc, newMemref); builder.setInsertionPoint(&candidateLoop.getLoopBody().front(), candidateLoop.getLoopBody().front().begin()); AffineExpr induc = getAffineDimExpr(0, allocOp.getContext()); AffineExpr init = getAffineDimExpr(1, allocOp.getContext()); AffineExpr step = getAffineDimExpr(2, allocOp.getContext()); AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier; auto map = AffineMap::get(3, 0, expr); std::array operands = {*inductionVar, lowerBoundValue, stepValue}; Value bufferIndex = builder.create(loc, map, operands); SmallVector offsets, sizes, strides; offsets.push_back(bufferIndex); offsets.append(oldShape.size(), builder.getIndexAttr(0)); strides.assign(oldShape.size() + 1, builder.getIndexAttr(1)); sizes.push_back(builder.getIndexAttr(1)); for (int64_t size : oldShape) sizes.push_back(builder.getIndexAttr(size)); auto dstMemref = memref::SubViewOp::inferRankReducedResultType( allocOp.getType().getShape(), newMemref, offsets, sizes, strides) .cast(); Value subview = builder.create(loc, dstMemref, newAlloc, offsets, sizes, strides); replaceUsesAndPropagateType(allocOp, subview, builder); allocOp.erase(); return success(); }