119efe141SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 219efe141SMatthias Springer // 319efe141SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 419efe141SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 519efe141SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 619efe141SMatthias Springer // 719efe141SMatthias Springer //===----------------------------------------------------------------------===// 819efe141SMatthias Springer 919efe141SMatthias Springer #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" 1019efe141SMatthias Springer 1119efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1219efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 1319efe141SMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 1419efe141SMatthias Springer #include "mlir/IR/Dialect.h" 1519efe141SMatthias Springer #include "mlir/IR/Operation.h" 1619efe141SMatthias Springer #include "mlir/IR/PatternMatch.h" 1719efe141SMatthias Springer 1819efe141SMatthias Springer using namespace mlir; 1919efe141SMatthias Springer using namespace mlir::bufferization; 2019efe141SMatthias Springer using namespace mlir::scf; 2119efe141SMatthias Springer 2219efe141SMatthias Springer namespace mlir { 2319efe141SMatthias Springer namespace scf { 2419efe141SMatthias Springer namespace { 2519efe141SMatthias Springer 2619efe141SMatthias Springer // bufferization.to_memref is not allowed to change the rank. 
/// In debug builds, verify that wrapping `tensor` in a bufferization.to_memref
/// op of type `memrefType` would be rank-preserving. Unranked tensors are
/// exempt from the check. No-op in release builds.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  /// Return the yield operand at the same position as `opResult`.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // Position of `opResult` within the op's result list.
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  /// Bufferize scf.execute_region: create a new op with memref result types,
  /// move the region over, wrap yielded tensors in to_memref ops, and wrap
  /// memref results back into tensors via to_tensor ops for the old users.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types: tensors become memrefs, all other types are
    // passed through unchanged.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator: every yielded tensor is wrapped in a to_memref op
    // matching the corresponding new result type.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op: memref results are converted back to
    // tensors so that existing tensor users keep type-checking.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    // The result buffer is the same as the yielded buffer.
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  /// Return the then- and else-branch yield operands at the same position as
  /// `opResult`.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope.
To allow for use-def chain traversal through 14119efe141SMatthias Springer // IfOps in the analysis, both corresponding yield values from the then/else 14219efe141SMatthias Springer // branches are considered to be aliasing with the result. 14319efe141SMatthias Springer auto ifOp = cast<scf::IfOp>(op); 14419efe141SMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(), 14519efe141SMatthias Springer llvm::find(op->getOpResults(), opResult)); 14619efe141SMatthias Springer return {&ifOp.thenYield()->getOpOperand(resultNum), 14719efe141SMatthias Springer &ifOp.elseYield()->getOpOperand(resultNum)}; 14819efe141SMatthias Springer } 14919efe141SMatthias Springer 15019efe141SMatthias Springer // TODO: For better bufferization results, this could return `true` only if 15119efe141SMatthias Springer // there is a memory write in one (or both) of the branches. Since this is not 15219efe141SMatthias Springer // allowed at the moment, we should never encounter scf.ifs that yield 15319efe141SMatthias Springer // unmodified tensors. Such scf.yield ops could just fold away. 15419efe141SMatthias Springer bool isMemoryWrite(Operation *op, OpResult opResult, 15519efe141SMatthias Springer const BufferizationState &state) const { 15619efe141SMatthias Springer // IfOp results are always considered memory writes in the analysis. This 15719efe141SMatthias Springer // design decision simplifies the analysis considerably. 
E.g., consider the 15819efe141SMatthias Springer // following test case: 15919efe141SMatthias Springer // 16019efe141SMatthias Springer // %0 = "some_writing_op" : tensor<?xf32> 16119efe141SMatthias Springer // %r = scf.if %c -> (tensor<?xf32>) { 16219efe141SMatthias Springer // scf.yield %0 16319efe141SMatthias Springer // } else { 16419efe141SMatthias Springer // %1 = "another_writing_op"(%0) : tensor<?xf32> 16519efe141SMatthias Springer // } 16619efe141SMatthias Springer // "some_reading_op"(%r) 16719efe141SMatthias Springer // 16819efe141SMatthias Springer // "another_writing_op" in the above example should be able to bufferize 16919efe141SMatthias Springer // inplace in the absence of another read of %0. However, if the scf.if op 17019efe141SMatthias Springer // would not be considered a "write", the analysis would detect the 17119efe141SMatthias Springer // following conflict: 17219efe141SMatthias Springer // 17319efe141SMatthias Springer // * read = some_reading_op 17419efe141SMatthias Springer // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) 17519efe141SMatthias Springer // * conflictingWrite = %1 17619efe141SMatthias Springer // 17719efe141SMatthias Springer // For more details, check the "scf.IfOp" section of the design document. 17819efe141SMatthias Springer return true; 17919efe141SMatthias Springer } 18019efe141SMatthias Springer 18119efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 18219efe141SMatthias Springer const BufferizationState &state) const { 18319efe141SMatthias Springer auto ifOp = cast<scf::IfOp>(op); 18419efe141SMatthias Springer 18519efe141SMatthias Springer // Compute new types of the bufferized scf.if op. 
18619efe141SMatthias Springer SmallVector<Type> newTypes; 18719efe141SMatthias Springer for (Type returnType : ifOp->getResultTypes()) { 18819efe141SMatthias Springer if (auto tensorType = returnType.dyn_cast<TensorType>()) { 18919efe141SMatthias Springer newTypes.push_back(getMemRefType(tensorType, state.getOptions())); 19019efe141SMatthias Springer } else { 19119efe141SMatthias Springer newTypes.push_back(returnType); 19219efe141SMatthias Springer } 19319efe141SMatthias Springer } 19419efe141SMatthias Springer 19519efe141SMatthias Springer // Create new op. 19619efe141SMatthias Springer auto newIfOp = 19719efe141SMatthias Springer rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(), 19819efe141SMatthias Springer /*withElseRegion=*/true); 19919efe141SMatthias Springer 20019efe141SMatthias Springer // Remove terminators. 20119efe141SMatthias Springer if (!newIfOp.thenBlock()->empty()) { 20219efe141SMatthias Springer rewriter.eraseOp(newIfOp.thenBlock()->getTerminator()); 20319efe141SMatthias Springer rewriter.eraseOp(newIfOp.elseBlock()->getTerminator()); 20419efe141SMatthias Springer } 20519efe141SMatthias Springer 20619efe141SMatthias Springer // Move over then/else blocks. 20719efe141SMatthias Springer rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); 20819efe141SMatthias Springer rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); 20919efe141SMatthias Springer 21019efe141SMatthias Springer // Update scf.yield of new then-block. 
21119efe141SMatthias Springer auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator()); 21219efe141SMatthias Springer rewriter.setInsertionPoint(thenYieldOp); 21319efe141SMatthias Springer SmallVector<Value> thenYieldValues; 21419efe141SMatthias Springer for (OpOperand &operand : thenYieldOp->getOpOperands()) { 21519efe141SMatthias Springer if (operand.get().getType().isa<TensorType>()) { 21619efe141SMatthias Springer ensureToMemrefOpIsValid(operand.get(), 21719efe141SMatthias Springer newTypes[operand.getOperandNumber()]); 21819efe141SMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 21919efe141SMatthias Springer operand.get().getLoc(), newTypes[operand.getOperandNumber()], 22019efe141SMatthias Springer operand.get()); 22119efe141SMatthias Springer operand.set(toMemrefOp); 22219efe141SMatthias Springer } 22319efe141SMatthias Springer } 22419efe141SMatthias Springer 22519efe141SMatthias Springer // Update scf.yield of new else-block. 22619efe141SMatthias Springer auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator()); 22719efe141SMatthias Springer rewriter.setInsertionPoint(elseYieldOp); 22819efe141SMatthias Springer SmallVector<Value> elseYieldValues; 22919efe141SMatthias Springer for (OpOperand &operand : elseYieldOp->getOpOperands()) { 23019efe141SMatthias Springer if (operand.get().getType().isa<TensorType>()) { 23119efe141SMatthias Springer ensureToMemrefOpIsValid(operand.get(), 23219efe141SMatthias Springer newTypes[operand.getOperandNumber()]); 23319efe141SMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 23419efe141SMatthias Springer operand.get().getLoc(), newTypes[operand.getOperandNumber()], 23519efe141SMatthias Springer operand.get()); 23619efe141SMatthias Springer operand.set(toMemrefOp); 23719efe141SMatthias Springer } 23819efe141SMatthias Springer } 23919efe141SMatthias Springer 24019efe141SMatthias Springer // Replace op results. 
24119efe141SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); 24219efe141SMatthias Springer 24319efe141SMatthias Springer return success(); 24419efe141SMatthias Springer } 24519efe141SMatthias Springer 24619efe141SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 24719efe141SMatthias Springer const BufferizationState &state) const { 24819efe141SMatthias Springer // IfOp results are equivalent to their corresponding yield values if both 24919efe141SMatthias Springer // yield values are equivalent to each other. 25019efe141SMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(op); 25119efe141SMatthias Springer SmallVector<OpOperand *> yieldValues = 25219efe141SMatthias Springer bufferizableOp.getAliasingOpOperand(opResult, state); 25319efe141SMatthias Springer assert(yieldValues.size() == 2 && "expected 2 yield values"); 25419efe141SMatthias Springer bool equivalentYields = state.areEquivalentBufferizedValues( 25519efe141SMatthias Springer yieldValues[0]->get(), yieldValues[1]->get()); 25619efe141SMatthias Springer return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None; 25719efe141SMatthias Springer } 25819efe141SMatthias Springer }; 25919efe141SMatthias Springer 26019efe141SMatthias Springer /// Bufferization of scf.for. Replace with a new scf.for that operates on 26119efe141SMatthias Springer /// memrefs. 26219efe141SMatthias Springer struct ForOpInterface 26319efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<ForOpInterface, 26419efe141SMatthias Springer scf::ForOp> { 26519efe141SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 26619efe141SMatthias Springer const BufferizationState &state) const { 26719efe141SMatthias Springer // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of 26819efe141SMatthias Springer // its matching bbArg may. 
    auto forOp = cast<scf::ForOp>(op);
    // Delegate to the loop-body bbArg that corresponds to this init operand.
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write. This is
    // to simplify the analysis.
    // TODO: Consider doing sth. like isValueWritten.
    return true;
  }

  /// A tensor init operand aliases with the scf.for result at the same
  /// position; non-ranked-tensor operands alias with nothing.
  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    if (!opOperand.get().getType().isa<RankedTensorType>())
      return {};
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    // The loop body's terminator yields the value carried to the next
    // iteration at the same position as the result.
    auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Bufferize scf.for: create a new scf.for whose iter_args are memrefs,
  /// move the loop body over, and bridge the tensor/memref worlds with
  /// to_tensor ops (at the top of the body) and to_memref ops (at the yield).
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs()))
      if (it.value().getType().isa<TensorType>())
        indices.insert(it.index());

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    // Tensor init operands are replaced with their buffers; this may fail
    // (e.g. if an out-of-place copy cannot be created).
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    // `mergeBlocks` below expects replacement values for all bbArgs,
    // including the induction variable.
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present. (iterArgs.size() == 1 means the loop
    // carries no iter_args, i.e. only the induction variable is present.)
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop: yielded tensors are converted back to the
    // memref type of the matching init arg.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          ensureToMemrefOpIsValid(val, initArgs[index].getType());
          return rewriter.create<bufferization::ToMemrefOp>(
              val.getLoc(), initArgs[index].getType(), val);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are aliasing with their
  /// corresponding bbArgs. This is required because the i-th OpResult of an
  /// scf.for op is currently assumed to alias with the i-th iter_arg (in the
  /// absence of conflicts).
  LogicalResult verifyAnalysis(Operation *op,
                               const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      // Only tensor yields are constrained; other values are ignored.
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
        // TODO: this could get resolved with copies but it can also turn into
        // swaps so we need to be careful about order of copies.
        return yieldOp->emitError()
               << "Yield operand #" << operand.getOperandNumber()
               << " does not bufferize to a buffer that is aliasing the "
                  "matching"
               << " enclosing scf::for operand";
    }
    return success();
  }
};

/// Bufferization of scf.yield.
Bufferized as part of their enclosing ops, so 42219efe141SMatthias Springer /// this is for analysis only. 42319efe141SMatthias Springer struct YieldOpInterface 42419efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<YieldOpInterface, 42519efe141SMatthias Springer scf::YieldOp> { 42619efe141SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 42719efe141SMatthias Springer const BufferizationState &state) const { 42819efe141SMatthias Springer return true; 42919efe141SMatthias Springer } 43019efe141SMatthias Springer 43119efe141SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 43219efe141SMatthias Springer const BufferizationState &state) const { 43319efe141SMatthias Springer return false; 43419efe141SMatthias Springer } 43519efe141SMatthias Springer 436585a8a32SMatthias Springer SmallVector<OpResult> 437585a8a32SMatthias Springer getAliasingOpResult(Operation *op, OpOperand &opOperand, 43819efe141SMatthias Springer const BufferizationState &state) const { 43919efe141SMatthias Springer if (isa<scf::IfOp>(op->getParentOp())) 440585a8a32SMatthias Springer return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; 44119efe141SMatthias Springer if (isa<scf::ExecuteRegionOp>(op->getParentOp())) 442585a8a32SMatthias Springer return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; 443585a8a32SMatthias Springer return {}; 44419efe141SMatthias Springer } 44519efe141SMatthias Springer 44619efe141SMatthias Springer bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, 44719efe141SMatthias Springer const BufferizationState &state) const { 44819efe141SMatthias Springer // Yield operands always bufferize inplace. Otherwise, an alloc + copy 44919efe141SMatthias Springer // may be generated inside the block. We should not return/yield allocations 45019efe141SMatthias Springer // when possible. 
45119efe141SMatthias Springer return true; 45219efe141SMatthias Springer } 45319efe141SMatthias Springer 45419efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 45519efe141SMatthias Springer const BufferizationState &state) const { 45619efe141SMatthias Springer auto yieldOp = cast<scf::YieldOp>(op); 45719efe141SMatthias Springer if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>( 45819efe141SMatthias Springer yieldOp->getParentOp())) 45919efe141SMatthias Springer return yieldOp->emitError("unsupported scf::YieldOp parent"); 46019efe141SMatthias Springer return success(); 46119efe141SMatthias Springer } 46219efe141SMatthias Springer }; 46319efe141SMatthias Springer 46419efe141SMatthias Springer } // namespace 46519efe141SMatthias Springer } // namespace scf 46619efe141SMatthias Springer } // namespace mlir 46719efe141SMatthias Springer 46819efe141SMatthias Springer void mlir::scf::registerBufferizableOpInterfaceExternalModels( 46919efe141SMatthias Springer DialectRegistry ®istry) { 47019efe141SMatthias Springer registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>(); 47119efe141SMatthias Springer registry.addOpInterface<ForOp, ForOpInterface>(); 47219efe141SMatthias Springer registry.addOpInterface<IfOp, IfOpInterface>(); 47319efe141SMatthias Springer registry.addOpInterface<YieldOp, YieldOpInterface>(); 47419efe141SMatthias Springer registry 47519efe141SMatthias Springer .addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>(); 47619efe141SMatthias Springer } 477