//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
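    // For example (illustrative IR; the exact memref type is derived from the
    // bufferization options), a tensor-yielding terminator such as
    //
    //   scf.yield %t : tensor<5xf32>
    //
    // is rewritten to yield a memref instead:
    //
    //   %m = bufferization.to_memref %t : memref<5xf32>
    //   scf.yield %m : memref<5xf32>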
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
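    // The rewritten scf.if yields memrefs; the op's tensor results are then
    // reconstructed from them via `replaceOpWithBufferizedValues` below.
    // Roughly (illustrative IR):
    //
    //   %0 = scf.if %c -> (memref<5xf32>) { ... }
    //   %r = bufferization.to_tensor %0 : memref<5xf32>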
    auto newIfOp = rewriter.create<scf::IfOp>(
        ifOp.getLoc(), newTypes, ifOp.getCondition(),
        /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp =
        cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp =
        cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }
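  // For example (illustrative IR), in
  //
  //   %r = scf.for ... iter_args(%a = %t) -> (tensor<5xf32>) {
  //     %w = "writing_op"(%a) : (tensor<5xf32>) -> tensor<5xf32>
  //     scf.yield %w : tensor<5xf32>
  //   }
  //
  // %r is equivalent to %t only if the analysis finds the yielded %w and the
  // bbArg %a to bufferize to the same buffer.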
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }
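  // `bufferize` below rewrites a loop over tensors into a loop over memrefs.
  // Roughly (illustrative IR; the exact memref types and layout maps depend
  // on the bufferization options):
  //
  //   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %t)
  //       -> (tensor<5xf32>) { ... }
  //
  // becomes
  //
  //   %m = <buffer of %t>
  //   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %m)
  //       -> (memref<5xf32>) {
  //     %at = bufferization.to_tensor %a : memref<5xf32>
  //     ...  // original body; the yield is rewritten to yield a memref
  //   }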
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    SmallVector<bool> equivalentYields;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
      if (it.value().getType().isa<TensorType>()) {
        indices.insert(it.index());
        BufferRelation relation = bufferizableOp.bufferRelation(
            forOp->getResult(it.index()), state.getAnalysisState());
        equivalentYields.push_back(relation == BufferRelation::Equivalent);
      } else {
        equivalentYields.push_back(false);
      }
    }

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          ensureToMemrefOpIsValid(val, initArgs[index].getType());
          Value yieldedVal = rewriter.create<bufferization::ToMemrefOp>(
              val.getLoc(), initArgs[index].getType(), val);

          if (equivalentYields[index])
            // Yielded value is equivalent to the corresponding iter_arg bbArg.
            // Yield the value directly. Most IR should be like that.
            // Everything else must be resolved with copies and is potentially
            // inefficient. By default, such problematic IR would already have
            // been rejected during `verifyAnalysis`, unless
            // `allow-return-allocs` is set.
            return yieldedVal;

          // It is not certain that the yielded value and the iter_arg bbArg
          // have the same buffer. Allocate a new buffer and copy. The yielded
          // buffer will get deallocated by `deallocateBuffers`.
          // TODO: There are cases in which it is not necessary to return a new
          // buffer allocation. E.g., when equivalent values are yielded in a
          // different order. This could be resolved with copies.
          Optional<Value> yieldedAlloc = state.createAlloc(
              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
          // TODO: We should rollback, but for now just assume that this
          // always succeeds.
          assert(yieldedAlloc.hasValue() && "could not create alloc");
          LogicalResult copyStatus =
              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
                                          *yieldedAlloc, state.getOptions());
          (void)copyStatus;
          assert(succeeded(copyStatus) && "could not create memcpy");

          if (yieldedVal.getType() == yieldedAlloc->getType())
            return *yieldedAlloc;

          // The iter_arg memref type has a layout map. Cast the new buffer to
          // the same type.
          // TODO: In case the iter_arg has a layout map that is not the fully
          // dynamic one, we cannot cast the new buffer. In that case, the
          // iter_arg must be changed to the fully dynamic layout map. (And
          // then the new buffer can be cast.)
          assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
                                                   yieldedVal.getType()) &&
                 "scf.for op bufferization: cast incompatible");
          Value casted = rewriter.create<memref::CastOp>(
              val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
          return casted;
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. Otherwise, an alloc + copy is inserted and yielded
  /// from the loop. This could be a performance problem, so it must be
  /// explicitly allowed with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
        return yieldOp->emitError()
               << "Yield operand #" << operand.getOperandNumber()
               << " does not bufferize to a buffer that is aliasing the "
                  "matching enclosing scf::for operand";
    }
    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
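/// For example (illustrative IR), the `scf.yield` in
///
///   %r = scf.if %c -> (tensor<5xf32>) {
///     scf.yield %t : tensor<5xf32>
///   } else { ... }
///
/// is rewritten by the enclosing op's `bufferize` method (here: scf.if), not
/// by this model; `bufferize` below therefore only checks the parent op.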
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::ForOp, scf::IfOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
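// Note: Clients are expected to add this extension to their DialectRegistry
// before creating the MLIRContext, e.g. (hypothetical client code):
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);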