//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "llvm/Support/Debug.h" namespace mlir { namespace bufferization { #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc" } // namespace bufferization } // namespace mlir #define DEBUG_TYPE "bufferizable-op-interface" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) using namespace mlir; using namespace bufferization; //===----------------------------------------------------------------------===// // BufferizationOptions //===----------------------------------------------------------------------===// // Default constructor for BufferizationOptions. BufferizationOptions::BufferizationOptions() {} BufferizableOpInterface BufferizationOptions::dynCastBufferizableOp(Operation *op) const { if (isOpAllowed(op)) return dyn_cast(op); return nullptr; } BufferizableOpInterface BufferizationOptions::dynCastBufferizableOp(Value value) const { if (auto bufferizableOp = value.getDefiningOp()) if (isOpAllowed(bufferizableOp.getOperation())) return bufferizableOp; return nullptr; } //===----------------------------------------------------------------------===// // Helper functions for BufferizableOpInterface //===----------------------------------------------------------------------===// static void setInsertionPointAfter(OpBuilder &b, Value value) { if (auto bbArg = value.dyn_cast()) { b.setInsertionPointToStart(bbArg.getOwner()); } else { b.setInsertionPointAfter(value.getDefiningOp()); } } /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. SmallVector BufferizationState::getAliasingOpOperand(OpResult result) const { if (Operation *op = result.getDefiningOp()) if (auto bufferizableOp = dyn_cast(op)) return bufferizableOp.getAliasingOpOperand(result, *this); return {}; } /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. Return an empty OpResult if the op is not bufferizable. OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.getAliasingOpResult(opOperand, *this); return OpResult(); } /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the /// op is not bufferizable. bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); // Unknown op that returns a tensor. The inplace analysis does not support it. // Conservatively return true. return true; } /// Return true if `opOperand` bufferizes to a memory write. Return /// `true` if the op is not bufferizable. bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); // Unknown op that returns a tensor. The inplace analysis does not support it. // Conservatively return true. return true; } /// Return true if `opOperand` does neither read nor write but bufferizes to an /// alias. Return false if the op is not bufferizable. bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); // Unknown op that returns a tensor. The inplace analysis does not support it. // Conservatively return false. return false; } /// Return true if the given value is read by an op that bufferizes to a memory /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). bool BufferizationState::isValueRead(Value value) const { assert(value.getType().isa() && "expected TensorType"); SmallVector workingSet; for (OpOperand &use : value.getUses()) workingSet.push_back(&use); while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) return true; } return false; } // Starting from `value`, follow the use-def chain in reverse, always selecting // the aliasing OpOperands. Find and return Values for which `condition` // evaluates to true. OpOperands of such matching Values are not traversed any // further. llvm::SetVector BufferizationState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition) const { llvm::SetVector result, workingSet; workingSet.insert(value); while (!workingSet.empty()) { Value value = workingSet.pop_back_val(); if (condition(value) || value.isa()) { result.insert(value); continue; } OpResult opResult = value.cast(); SmallVector opOperands = getAliasingOpOperand(opResult); if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) { result.insert(value); continue; } for (OpOperand *o : opOperands) workingSet.insert(o->get()); } return result; } // Find the Values of the last preceding write of a given Value. llvm::SetVector BufferizationState::findLastPrecedingWrite(Value value) const { return findValueInReverseUseDefChain(value, [&](Value value) { Operation *op = value.getDefiningOp(); if (!op) return true; auto bufferizableOp = options.dynCastBufferizableOp(op); if (!bufferizableOp) return true; return bufferizableOp.isMemoryWrite(value.cast(), *this); }); } BufferizationState::BufferizationState(const BufferizationOptions &options) : options(options) {} // bufferization.to_memref is not allowed to change the rank. static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { #ifndef NDEBUG auto rankedTensorType = tensor.getType().dyn_cast(); assert((!rankedTensorType || memrefType.cast().getRank() == rankedTensorType.getRank()) && "to_memref would be invalid: mismatching ranks"); #endif } static Value lookupBuffer(RewriterBase &rewriter, Value tensor) { assert(tensor.getType().isa() && "unexpected non-tensor type"); // Replace "%t = to_tensor %m" with %m. if (auto toTensorOp = tensor.getDefiningOp()) return toTensorOp.memref(); // Insert to_memref op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, tensor); Type memrefType; if (auto rankedTensorType = tensor.getType().dyn_cast()) { memrefType = getDynamicMemRefType(rankedTensorType); } else { memrefType = getUnrankedMemRefType( tensor.getType().cast().getElementType()); } ensureToMemrefOpIsValid(tensor, memrefType); return rewriter.create(tensor.getLoc(), memrefType, tensor); } /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. FailureOr BufferizationState::getBuffer( RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace, Optional customCopyInsertionPoint) const { OpBuilder::InsertionGuard guard(rewriter); Operation *op = opOperand.getOwner(); Location loc = op->getLoc(); Value operand = opOperand.get(); Value operandBuffer = lookupBuffer(rewriter, operand); if (forceInPlace || isInPlace(opOperand)) return operandBuffer; // Bufferizing out-of-place: Allocate a new buffer. // Move insertion point right after `operandBuffer`. That is where the // allocation should be inserted (in the absence of allocation hoisting). setInsertionPointAfter(rewriter, operandBuffer); // Allocate the result buffer. FailureOr resultBuffer = createAlloc(rewriter, loc, operandBuffer, options.createDeallocs, options); if (failed(resultBuffer)) return failure(); // Do not copy if the last preceding writes of `operand` are ops that do // not write (skipping ops that merely create aliases). E.g., InitTensorOp. // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA // use-def chain, it returns that value, regardless of whether it is a // memory write or not. SetVector lastWrites = findLastPrecedingWrite(operand); if (llvm::none_of(lastWrites, [&](Value lastWrite) { if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) return bufferizableOp.isMemoryWrite(lastWrite.cast(), *this); return true; })) return resultBuffer; // Do not copy if the copied data is never read. OpResult aliasingOpResult = getAliasingOpResult(opOperand); if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) && !isValueRead(aliasingOpResult)) return resultBuffer; // Do not copy if this op does not read the data, but writes it. if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand)) return resultBuffer; if (customCopyInsertionPoint) { rewriter.setInsertionPoint(*customCopyInsertionPoint); } else { // The copy happens right before the op that is bufferized. rewriter.setInsertionPoint(op); } if (failed( createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options))) return failure(); return resultBuffer; } void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values) { OpBuilder::InsertionGuard g(rewriter); // Replace all OpResults with the given values. for (OpResult opResult : op->getOpResults()) { // Skip OpResult if it has no uses. if (opResult.getUses().empty()) continue; Value replacement = values[opResult.getResultNumber()]; if (opResult.getType().isa()) { // The OpResult is a tensor. Such values are replaced with memrefs during // bufferization. assert((replacement.getType().isa() || replacement.getType().isa()) && "tensor op result should be replaced with a memref value"); // The existing uses of the OpResult still expect a tensor. Insert a // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually // loose all of its users and eventually DCE away. setInsertionPointAfter(rewriter, replacement); replacement = rewriter.create( replacement.getLoc(), replacement); } opResult.replaceAllUsesWith(replacement); } rewriter.eraseOp(op); } //===----------------------------------------------------------------------===// // Bufferization-specific scoped alloc/dealloc insertion support. //===----------------------------------------------------------------------===// /// Move the insertion point of the given builder to the beginning of a /// surrounding block as much as possible, while not crossing any allocation /// hoisting barriers. static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) { Operation *op = b.getInsertionBlock()->getParentOp(); while (op) { if (auto bufferizableOp = dyn_cast(op)) if (bufferizableOp.isAllocationHoistingBarrier()) break; op = op->getParentOp(); } if (!op) { // No allocation hoisting barrier found. Hoist to FuncOp. op = b.getInsertionBlock()->getParentOp(); if (!isa(op)) op = op->getParentOfType(); assert(op && "could not find enclosing FuncOp"); } // TODO: Handle cases where allocation hoisting barrier has more than one // region or block. assert(op->getNumRegions() == 1 && "allocation hoisting barriers with >1 regions not supported"); assert(op->getRegion(0).getBlocks().size() == 1 && "allocation hoisting barriers with >1 blocks not supported"); b.setInsertionPointToStart(&(op->getRegion(0).front())); } /// Compute the type of the `memref` to use for allocating the buffer for /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the /// dynamic dimensions in the returned `memref` type. The function may also set /// the insertion point to an earlier location, where the allocation should /// happen ("allocation hoisting"). static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, Value shapedValue, SmallVectorImpl &dynShape) { MemRefType allocMemRefType = getContiguousMemRefType(shapedValue.getType().cast()); // Compute the dynamic part of the shape. bool reifiedShapes = false; if (auto rankedOp = dyn_cast_or_null( shapedValue.getDefiningOp())) { ReifiedRankedShapedTypeDims resultDims; if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { reifiedShapes = true; OpResult resultValue = shapedValue.dyn_cast(); auto &shape = resultDims[resultValue.getResultNumber()]; for (const auto &dim : enumerate(allocMemRefType.getShape())) if (ShapedType::isDynamic(dim.value())) dynShape.push_back(shape[dim.index()]); } } if (!reifiedShapes) { for (const auto &dim : enumerate(allocMemRefType.getShape())) if (ShapedType::isDynamic(dim.value())) { assert((shapedValue.getType().isa() || shapedValue.getType().isa()) && "expected MemRef type"); dynShape.push_back( b.create(loc, shapedValue, dim.index())); } } // If the buffer is statically shaped, try to hoist it to the first enclosing // parallel region. // TODO: also hoist in the dynamic case. For now this relies on subsequent // calls to LICM and buffer hoisting which will most likely not succeed. // TODO: when packing, allocate a static bounding box which will enable more // hoisting. if (dynShape.empty()) moveInsertionPointToAllocationHoistingBarrier(b); return allocMemRefType; } /// Create an AllocOp/DeallocOp pair, where the AllocOp is after /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block. FailureOr bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref, const BufferizationOptions &options) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); // 1. Create memory allocation. assert(shapedValue.getType().isa()); MemRefType memRefType = shapedValue.getType().dyn_cast(); SmallVector dynShape; // Note: getAllocationTypeAndShape also sets the insertion point. MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); FailureOr allocated = createAlloc(b, loc, allocMemRefType, dynShape, options); if (failed(allocated)) return failure(); Value casted = allocated.getValue(); if (memRefType && memRefType != allocMemRefType) { assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(), memRefType) && "createAlloc: cast incompatible"); casted = b.create(loc, memRefType, allocated.getValue()); } if (deallocMemref) { // 2. Create memory deallocation. b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator()); if (failed(createDealloc(b, loc, allocated.getValue(), options))) return failure(); } return casted; } /// Create a memref allocation. FailureOr bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape, const BufferizationOptions &options) { if (options.allocationFn) return (*options.allocationFn)(b, loc, type, dynShape); // Default bufferallocation via AllocOp. Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; } /// Create a memref deallocation. LogicalResult bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, const BufferizationOptions &options) { if (options.deallocationFn) return (*options.deallocationFn)(b, loc, allocatedBuffer); // Default buffer deallocation via DeallocOp. b.create(loc, allocatedBuffer); return success(); } /// Create a memory copy between two memref buffers. LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc, Value from, Value to, const BufferizationOptions &options) { if (options.memCpyFn) return (*options.memCpyFn)(b, loc, from, to); b.create(loc, from, to); return success(); } //===----------------------------------------------------------------------===// // Bufferization-specific BlockAndValueMapping support with debugging. //===----------------------------------------------------------------------===// bool bufferization::isFunctionArgument(Value value) { auto bbArg = value.dyn_cast(); if (!bbArg) return false; return isa(bbArg.getOwner()->getParentOp()); } MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), layout, memorySpace); } UnrankedMemRefType bufferization::getUnrankedMemRefType(Type elementType, Attribute memorySpace) { return UnrankedMemRefType::get(elementType, memorySpace); } MemRefType bufferization::getDynamicMemRefType(RankedTensorType tensorType, unsigned addressSpace) { // TODO: address space decisions to connect with the actual alloc. int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; SmallVector dynamicStrides(tensorType.getRank(), ShapedType::kDynamicStrideOrOffset); AffineMap stridedLayout = makeStridedLinearLayoutMap( dynamicStrides, dynamicOffset, tensorType.getContext()); return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), stridedLayout, addressSpace); }