//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp = cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update all uses of the old op.
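    // (Illustrative sketch, hypothetical types: each memref result of the new
    // op is rewrapped as a tensor so that remaining users keep verifying,
    // e.g.:
    //   %0 = scf.execute_region -> memref<5xf32>
    //   %1 = bufferization.to_tensor %0 : memref<5xf32>
    // Non-tensor results are forwarded unchanged.)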
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref
      // type with a fully dynamic layout map.
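      // (Illustrative sketch with assumed types: if one branch yields a
      // memref<5xf32> with identity layout and the other yields one with a
      // different layout, both are cast to a memref<5xf32, ...> whose strides
      // and offset are fully dynamic, giving the new scf.if a single result
      // type.)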
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }
    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes,
                                              ifOp.getCondition(),
                                              /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
static DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                              ValueRange yieldedValues,
                                              const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
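  // (Sketch with assumed types: casting memref<5xf32> to a memref<5xf32, ...>
  // with fully dynamic strides/offset is cast-compatible; casting to an
  // unrelated static non-identity layout generally is not.)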
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                         BaseMemRefType type,
                                         const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
  if (failed(yieldedVal))
    return failure();
  return castBuffer(rewriter, *yieldedVal, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      FailureOr<Value> maybeVal = func(val, idx);
      if (failed(maybeVal))
        return failure();
      result.push_back(*maybeVal);
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
static FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                options);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read; one of the uses of
    // its matching bbArg may.
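    // (Illustrative, hypothetical IR: in
    //   %r = scf.for ... iter_args(%t = %A) -> (tensor<5xf32>) {
    //     %w = "fully_overwriting_op"() : () -> (tensor<5xf32>)
    //     scf.yield %w : tensor<5xf32>
    //   }
    // %t has no reading use, so the init_arg %A is not read by the loop.)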
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg,
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
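    // (Sketch, hypothetical IR: a non-equivalent yield such as
    //   scf.yield %other : tensor<5xf32>
    // is rewritten below to yield a fresh copy instead,
    //   %copy = bufferization.alloc_tensor() copy(%other) : tensor<5xf32>
    //   scf.yield %copy : tensor<5xf32>
    // so that the corresponding OpResult does not alias any other buffer.)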
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, BlockArgument bbArg,
                const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    return bufferization::getBufferType(
        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
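/// Sketch of the rewrite (illustrative types, not verbatim pass output):
///   %r = scf.while (%w = %init) : (tensor<5xf32>) -> (tensor<5xf32>) {...}
/// becomes, roughly,
///   %r = scf.while (%w = %buf) : (memref<5xf32, ...>) -> (memref<5xf32, ...>)
/// with to_tensor/to_memref ops stitching the still tensor-based region
/// bodies to the new memref iter_args.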
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg,
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
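    // (Both terminators are handled below: the scf.condition args feed the
    // "after" region and the OpResults, while the scf.yield operands feed the
    // "before" region along the back edge.)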
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      afterYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  // TODO: Implement getBufferType interface method and infer buffer types.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          // TODO: error handling
          return bufferization::getBufferType(bbArg, options)->cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
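    // (Note: The result types equal the bufferized "after" bbArg types
    // computed above; the init_args keep their own, possibly different,
    // bufferized types.)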
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
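  /// (E.g., an scf.condition that forwards a tensor other than its
  /// corresponding "before" bbArg fails this check, as does an scf.yield
  /// that forwards a non-equivalent tensor in the "after" region.)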
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of the enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::ForOp, scf::IfOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
    // together with scf.while.)
    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
      return success();

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType =
              cast<BufferizableOpInterface>(forOp.getOperation())
                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                 options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();

  SmallVector<OpOperand *> result;
  terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });

  return result;
}

/// Bufferization of ForeachThreadOp.
/// This also bufferizes the terminator of the region. There are op interfaces
/// for the terminators (PerformConcurrentlyOp and ParallelInsertSliceOp), but
/// these are only used during analysis, not for bufferization.
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

#ifndef NDEBUG
    // ParallelInsertSliceOpInterface replaces all uses.
    for (OpResult opResult : foreachThreadOp->getOpResults())
      assert(opResult.getUses().empty() &&
             "expected that all uses were already replaced");
#endif // NDEBUG

    // Create new ForeachThreadOp without any results and drop the
    // automatically introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
        foreachThreadOp.getLoc(), newResultTypes,
        foreachThreadOp.getNumThreads(),
        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    rewriter.mergeBlocks(foreachThreadOp.getBody(),
                         newForeachThreadOp.getBody(),
                         {newForeachThreadOp.getBody()->getArguments()});

    // Remove the old op.
    rewriter.eraseOp(op);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
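// Typical usage (illustrative sketch, not part of this file's API): clients
// register these external models before running One-Shot Bufferize, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);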