119efe141SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 219efe141SMatthias Springer // 319efe141SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 419efe141SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 519efe141SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 619efe141SMatthias Springer // 719efe141SMatthias Springer //===----------------------------------------------------------------------===// 819efe141SMatthias Springer 919efe141SMatthias Springer #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" 1019efe141SMatthias Springer 1119efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1219efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 131e1eeae8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 14eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 1519efe141SMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 1619efe141SMatthias Springer #include "mlir/IR/Dialect.h" 1719efe141SMatthias Springer #include "mlir/IR/Operation.h" 1819efe141SMatthias Springer #include "mlir/IR/PatternMatch.h" 1919efe141SMatthias Springer 2019efe141SMatthias Springer using namespace mlir; 2119efe141SMatthias Springer using namespace mlir::bufferization; 2219efe141SMatthias Springer using namespace mlir::scf; 2319efe141SMatthias Springer 2419efe141SMatthias Springer namespace mlir { 2519efe141SMatthias Springer namespace scf { 2619efe141SMatthias Springer namespace { 2719efe141SMatthias Springer 2819efe141SMatthias Springer // bufferization.to_memref is not allowed to change the rank. 
2919efe141SMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 3019efe141SMatthias Springer #ifndef NDEBUG 3119efe141SMatthias Springer auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>(); 3219efe141SMatthias Springer assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() == 3319efe141SMatthias Springer rankedTensorType.getRank())) && 3419efe141SMatthias Springer "to_memref would be invalid: mismatching ranks"); 3519efe141SMatthias Springer #endif 3619efe141SMatthias Springer } 3719efe141SMatthias Springer 3819efe141SMatthias Springer /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not 3919efe141SMatthias Springer /// fully implemented at the moment. 4019efe141SMatthias Springer struct ExecuteRegionOpInterface 4119efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface, 4219efe141SMatthias Springer scf::ExecuteRegionOp> { 4319efe141SMatthias Springer SmallVector<OpOperand *> 4419efe141SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 459597b16aSMatthias Springer const AnalysisState &state) const { 4619efe141SMatthias Springer // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be 4719efe141SMatthias Springer // any SSA value that is in scope. To allow for use-def chain traversal 4819efe141SMatthias Springer // through ExecuteRegionOps in the analysis, the corresponding yield value 4919efe141SMatthias Springer // is considered to be aliasing with the result. 5019efe141SMatthias Springer auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 5119efe141SMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(), 5219efe141SMatthias Springer llvm::find(op->getOpResults(), opResult)); 5319efe141SMatthias Springer // TODO: Support multiple blocks. 
5419efe141SMatthias Springer assert(executeRegionOp.getRegion().getBlocks().size() == 1 && 5519efe141SMatthias Springer "expected exactly 1 block"); 5619efe141SMatthias Springer auto yieldOp = dyn_cast<scf::YieldOp>( 5719efe141SMatthias Springer executeRegionOp.getRegion().front().getTerminator()); 5819efe141SMatthias Springer assert(yieldOp && "expected scf.yield terminator in scf.execute_region"); 5919efe141SMatthias Springer return {&yieldOp->getOpOperand(resultNum)}; 6019efe141SMatthias Springer } 6119efe141SMatthias Springer 6219efe141SMatthias Springer // TODO: For better bufferization results, this could return `true` only if 6319efe141SMatthias Springer // there is a memory write in the region. 6419efe141SMatthias Springer bool isMemoryWrite(Operation *op, OpResult opResult, 659597b16aSMatthias Springer const AnalysisState &state) const { 6619efe141SMatthias Springer // Similar to scf.if, results of this op are always considered memory writes 6719efe141SMatthias Springer // in the analysis. This is a useful pattern for all ops that have tensor 6819efe141SMatthias Springer // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is 6919efe141SMatthias Springer // implemented in terms of `bufferizesToMemoryWrite`, which does not work on 7019efe141SMatthias Springer // ops without OpOperands. 7119efe141SMatthias Springer return true; 7219efe141SMatthias Springer } 7319efe141SMatthias Springer 7419efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 759597b16aSMatthias Springer BufferizationState &state) const { 7619efe141SMatthias Springer auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 7719efe141SMatthias Springer 7819efe141SMatthias Springer // Compute new result types. 
7919efe141SMatthias Springer SmallVector<Type> newResultTypes; 8019efe141SMatthias Springer for (Type type : executeRegionOp->getResultTypes()) { 8119efe141SMatthias Springer if (auto tensorType = type.dyn_cast<TensorType>()) { 8219efe141SMatthias Springer newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); 8319efe141SMatthias Springer } else { 8419efe141SMatthias Springer newResultTypes.push_back(type); 8519efe141SMatthias Springer } 8619efe141SMatthias Springer } 8719efe141SMatthias Springer 8819efe141SMatthias Springer // Create new op and move over region. 8919efe141SMatthias Springer auto newOp = 9019efe141SMatthias Springer rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes); 9119efe141SMatthias Springer newOp.getRegion().takeBody(executeRegionOp.getRegion()); 9219efe141SMatthias Springer 9319efe141SMatthias Springer // Update terminator. 9419efe141SMatthias Springer assert(newOp.getRegion().getBlocks().size() == 1 && 9519efe141SMatthias Springer "only 1 block supported"); 9619efe141SMatthias Springer Block *newBlock = &newOp.getRegion().front(); 9719efe141SMatthias Springer auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator()); 9819efe141SMatthias Springer rewriter.setInsertionPoint(yieldOp); 9919efe141SMatthias Springer SmallVector<Value> newYieldValues; 100bb6119ebSMehdi Amini for (const auto &it : llvm::enumerate(yieldOp.getResults())) { 10119efe141SMatthias Springer Value val = it.value(); 10219efe141SMatthias Springer if (val.getType().isa<TensorType>()) { 10319efe141SMatthias Springer newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>( 10419efe141SMatthias Springer yieldOp.getLoc(), newResultTypes[it.index()], val)); 10519efe141SMatthias Springer } else { 10619efe141SMatthias Springer newYieldValues.push_back(val); 10719efe141SMatthias Springer } 10819efe141SMatthias Springer } 10919efe141SMatthias Springer rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 
11019efe141SMatthias Springer 11119efe141SMatthias Springer // Update all uses of the old op. 11219efe141SMatthias Springer rewriter.setInsertionPointAfter(newOp); 11319efe141SMatthias Springer SmallVector<Value> newResults; 114bb6119ebSMehdi Amini for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { 11519efe141SMatthias Springer if (it.value().isa<TensorType>()) { 11619efe141SMatthias Springer newResults.push_back(rewriter.create<bufferization::ToTensorOp>( 11719efe141SMatthias Springer executeRegionOp.getLoc(), newOp->getResult(it.index()))); 11819efe141SMatthias Springer } else { 11919efe141SMatthias Springer newResults.push_back(newOp->getResult(it.index())); 12019efe141SMatthias Springer } 12119efe141SMatthias Springer } 12219efe141SMatthias Springer 12319efe141SMatthias Springer // Replace old op. 12419efe141SMatthias Springer rewriter.replaceOp(executeRegionOp, newResults); 12519efe141SMatthias Springer 12619efe141SMatthias Springer return success(); 12719efe141SMatthias Springer } 12819efe141SMatthias Springer 12919efe141SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 1309597b16aSMatthias Springer const AnalysisState &state) const { 13119efe141SMatthias Springer return BufferRelation::Equivalent; 13219efe141SMatthias Springer } 13319efe141SMatthias Springer }; 13419efe141SMatthias Springer 13519efe141SMatthias Springer /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. 13619efe141SMatthias Springer struct IfOpInterface 13719efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> { 13819efe141SMatthias Springer SmallVector<OpOperand *> 13919efe141SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 1409597b16aSMatthias Springer const AnalysisState &state) const { 14119efe141SMatthias Springer // IfOps do not have tensor OpOperands. 
The yielded value can be any SSA 14219efe141SMatthias Springer // value that is in scope. To allow for use-def chain traversal through 14319efe141SMatthias Springer // IfOps in the analysis, both corresponding yield values from the then/else 14419efe141SMatthias Springer // branches are considered to be aliasing with the result. 14519efe141SMatthias Springer auto ifOp = cast<scf::IfOp>(op); 14619efe141SMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(), 14719efe141SMatthias Springer llvm::find(op->getOpResults(), opResult)); 14819efe141SMatthias Springer return {&ifOp.thenYield()->getOpOperand(resultNum), 14919efe141SMatthias Springer &ifOp.elseYield()->getOpOperand(resultNum)}; 15019efe141SMatthias Springer } 15119efe141SMatthias Springer 15219efe141SMatthias Springer // TODO: For better bufferization results, this could return `true` only if 15319efe141SMatthias Springer // there is a memory write in one (or both) of the branches. Since this is not 15419efe141SMatthias Springer // allowed at the moment, we should never encounter scf.ifs that yield 15519efe141SMatthias Springer // unmodified tensors. Such scf.yield ops could just fold away. 15619efe141SMatthias Springer bool isMemoryWrite(Operation *op, OpResult opResult, 1579597b16aSMatthias Springer const AnalysisState &state) const { 15819efe141SMatthias Springer // IfOp results are always considered memory writes in the analysis. This 15919efe141SMatthias Springer // design decision simplifies the analysis considerably. 
E.g., consider the 16019efe141SMatthias Springer // following test case: 16119efe141SMatthias Springer // 16219efe141SMatthias Springer // %0 = "some_writing_op" : tensor<?xf32> 16319efe141SMatthias Springer // %r = scf.if %c -> (tensor<?xf32>) { 16419efe141SMatthias Springer // scf.yield %0 16519efe141SMatthias Springer // } else { 16619efe141SMatthias Springer // %1 = "another_writing_op"(%0) : tensor<?xf32> 16719efe141SMatthias Springer // } 16819efe141SMatthias Springer // "some_reading_op"(%r) 16919efe141SMatthias Springer // 17019efe141SMatthias Springer // "another_writing_op" in the above example should be able to bufferize 17119efe141SMatthias Springer // inplace in the absence of another read of %0. However, if the scf.if op 17219efe141SMatthias Springer // would not be considered a "write", the analysis would detect the 17319efe141SMatthias Springer // following conflict: 17419efe141SMatthias Springer // 17519efe141SMatthias Springer // * read = some_reading_op 17619efe141SMatthias Springer // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) 17719efe141SMatthias Springer // * conflictingWrite = %1 17819efe141SMatthias Springer // 17919efe141SMatthias Springer // For more details, check the "scf.IfOp" section of the design document. 18019efe141SMatthias Springer return true; 18119efe141SMatthias Springer } 18219efe141SMatthias Springer 18319efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1849597b16aSMatthias Springer BufferizationState &state) const { 18519efe141SMatthias Springer auto ifOp = cast<scf::IfOp>(op); 18619efe141SMatthias Springer 18719efe141SMatthias Springer // Compute new types of the bufferized scf.if op. 
18819efe141SMatthias Springer SmallVector<Type> newTypes; 18919efe141SMatthias Springer for (Type returnType : ifOp->getResultTypes()) { 19019efe141SMatthias Springer if (auto tensorType = returnType.dyn_cast<TensorType>()) { 19119efe141SMatthias Springer newTypes.push_back(getMemRefType(tensorType, state.getOptions())); 19219efe141SMatthias Springer } else { 19319efe141SMatthias Springer newTypes.push_back(returnType); 19419efe141SMatthias Springer } 19519efe141SMatthias Springer } 19619efe141SMatthias Springer 19719efe141SMatthias Springer // Create new op. 19819efe141SMatthias Springer auto newIfOp = 19919efe141SMatthias Springer rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(), 20019efe141SMatthias Springer /*withElseRegion=*/true); 20119efe141SMatthias Springer 20219efe141SMatthias Springer // Remove terminators. 20319efe141SMatthias Springer if (!newIfOp.thenBlock()->empty()) { 20419efe141SMatthias Springer rewriter.eraseOp(newIfOp.thenBlock()->getTerminator()); 20519efe141SMatthias Springer rewriter.eraseOp(newIfOp.elseBlock()->getTerminator()); 20619efe141SMatthias Springer } 20719efe141SMatthias Springer 20819efe141SMatthias Springer // Move over then/else blocks. 20919efe141SMatthias Springer rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); 21019efe141SMatthias Springer rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); 21119efe141SMatthias Springer 21219efe141SMatthias Springer // Update scf.yield of new then-block. 
21319efe141SMatthias Springer auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator()); 21419efe141SMatthias Springer rewriter.setInsertionPoint(thenYieldOp); 21519efe141SMatthias Springer SmallVector<Value> thenYieldValues; 21619efe141SMatthias Springer for (OpOperand &operand : thenYieldOp->getOpOperands()) { 21719efe141SMatthias Springer if (operand.get().getType().isa<TensorType>()) { 21819efe141SMatthias Springer ensureToMemrefOpIsValid(operand.get(), 21919efe141SMatthias Springer newTypes[operand.getOperandNumber()]); 22019efe141SMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 22119efe141SMatthias Springer operand.get().getLoc(), newTypes[operand.getOperandNumber()], 22219efe141SMatthias Springer operand.get()); 22319efe141SMatthias Springer operand.set(toMemrefOp); 22419efe141SMatthias Springer } 22519efe141SMatthias Springer } 22619efe141SMatthias Springer 22719efe141SMatthias Springer // Update scf.yield of new else-block. 22819efe141SMatthias Springer auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator()); 22919efe141SMatthias Springer rewriter.setInsertionPoint(elseYieldOp); 23019efe141SMatthias Springer SmallVector<Value> elseYieldValues; 23119efe141SMatthias Springer for (OpOperand &operand : elseYieldOp->getOpOperands()) { 23219efe141SMatthias Springer if (operand.get().getType().isa<TensorType>()) { 23319efe141SMatthias Springer ensureToMemrefOpIsValid(operand.get(), 23419efe141SMatthias Springer newTypes[operand.getOperandNumber()]); 23519efe141SMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 23619efe141SMatthias Springer operand.get().getLoc(), newTypes[operand.getOperandNumber()], 23719efe141SMatthias Springer operand.get()); 23819efe141SMatthias Springer operand.set(toMemrefOp); 23919efe141SMatthias Springer } 24019efe141SMatthias Springer } 24119efe141SMatthias Springer 24219efe141SMatthias Springer // Replace op results. 
24319efe141SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); 24419efe141SMatthias Springer 24519efe141SMatthias Springer return success(); 24619efe141SMatthias Springer } 24719efe141SMatthias Springer 24819efe141SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 2499597b16aSMatthias Springer const AnalysisState &state) const { 25019efe141SMatthias Springer // IfOp results are equivalent to their corresponding yield values if both 25119efe141SMatthias Springer // yield values are equivalent to each other. 25219efe141SMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(op); 25319efe141SMatthias Springer SmallVector<OpOperand *> yieldValues = 25419efe141SMatthias Springer bufferizableOp.getAliasingOpOperand(opResult, state); 25519efe141SMatthias Springer assert(yieldValues.size() == 2 && "expected 2 yield values"); 25619efe141SMatthias Springer bool equivalentYields = state.areEquivalentBufferizedValues( 25719efe141SMatthias Springer yieldValues[0]->get(), yieldValues[1]->get()); 25819efe141SMatthias Springer return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None; 25919efe141SMatthias Springer } 26019efe141SMatthias Springer }; 26119efe141SMatthias Springer 262417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all values 263417e1c7dSMatthias Springer /// that have a tensor type. 264417e1c7dSMatthias Springer static DenseSet<int64_t> getTensorIndices(ValueRange values) { 265417e1c7dSMatthias Springer DenseSet<int64_t> result; 266417e1c7dSMatthias Springer for (const auto &it : llvm::enumerate(values)) 267417e1c7dSMatthias Springer if (it.value().getType().isa<TensorType>()) 268417e1c7dSMatthias Springer result.insert(it.index()); 269417e1c7dSMatthias Springer return result; 270417e1c7dSMatthias Springer } 271417e1c7dSMatthias Springer 272417e1c7dSMatthias Springer /// Helper function for loop bufferization. 
Return the indices of all 273417e1c7dSMatthias Springer /// bbArg/yielded value pairs who's buffer relation is "Equivalent". 274a5d09c63SMatthias Springer DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs, 275417e1c7dSMatthias Springer ValueRange yieldedValues, 276417e1c7dSMatthias Springer const AnalysisState &state) { 277417e1c7dSMatthias Springer DenseSet<int64_t> result; 278417e1c7dSMatthias Springer int64_t counter = 0; 279417e1c7dSMatthias Springer for (const auto &it : llvm::zip(bbArgs, yieldedValues)) { 280417e1c7dSMatthias Springer if (!std::get<0>(it).getType().isa<TensorType>()) 281417e1c7dSMatthias Springer continue; 282417e1c7dSMatthias Springer if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it))) 283417e1c7dSMatthias Springer result.insert(counter); 284417e1c7dSMatthias Springer counter++; 285417e1c7dSMatthias Springer } 286417e1c7dSMatthias Springer return result; 287417e1c7dSMatthias Springer } 288417e1c7dSMatthias Springer 289417e1c7dSMatthias Springer /// Helper function for loop bufferization. Cast the given buffer to the given 290417e1c7dSMatthias Springer /// memref type. 291417e1c7dSMatthias Springer static Value castBuffer(OpBuilder &b, Value buffer, Type type) { 292417e1c7dSMatthias Springer assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType"); 293417e1c7dSMatthias Springer assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType"); 294417e1c7dSMatthias Springer // If the buffer already has the correct type, no cast is needed. 295417e1c7dSMatthias Springer if (buffer.getType() == type) 296417e1c7dSMatthias Springer return buffer; 297417e1c7dSMatthias Springer // TODO: In case `type` has a layout map that is not the fully dynamic 298417e1c7dSMatthias Springer // one, we may not be able to cast the buffer. In that case, the loop 299417e1c7dSMatthias Springer // iter_arg's layout map must be changed (see uses of `castBuffer`). 
300417e1c7dSMatthias Springer assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && 301417e1c7dSMatthias Springer "scf.while op bufferization: cast incompatible"); 302417e1c7dSMatthias Springer return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult(); 303417e1c7dSMatthias Springer } 304417e1c7dSMatthias Springer 305417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the bufferized values of the 306417e1c7dSMatthias Springer /// given OpOperands. If an operand is not a tensor, return the original value. 307417e1c7dSMatthias Springer static SmallVector<Value> getBuffers(RewriterBase &rewriter, 308417e1c7dSMatthias Springer MutableArrayRef<OpOperand> operands, 309417e1c7dSMatthias Springer BufferizationState &state) { 310417e1c7dSMatthias Springer SmallVector<Value> result; 311417e1c7dSMatthias Springer for (OpOperand &opOperand : operands) { 312417e1c7dSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) { 313417e1c7dSMatthias Springer FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand); 314417e1c7dSMatthias Springer if (failed(resultBuffer)) 315417e1c7dSMatthias Springer return {}; 316417e1c7dSMatthias Springer result.push_back(*resultBuffer); 317417e1c7dSMatthias Springer } else { 318417e1c7dSMatthias Springer result.push_back(opOperand.get()); 319417e1c7dSMatthias Springer } 320417e1c7dSMatthias Springer } 321417e1c7dSMatthias Springer return result; 322417e1c7dSMatthias Springer } 323417e1c7dSMatthias Springer 324417e1c7dSMatthias Springer /// Helper function for loop bufferization. Compute the buffer that should be 325417e1c7dSMatthias Springer /// yielded from a loop block (loop body or loop condition). If the given tensor 326417e1c7dSMatthias Springer /// is equivalent to the corresponding block argument (as indicated by 327417e1c7dSMatthias Springer /// `isEquivalent`), the buffer can be yielded directly. 
Otherwise, a new buffer 328417e1c7dSMatthias Springer /// copy must be yielded. 329417e1c7dSMatthias Springer /// 330417e1c7dSMatthias Springer /// According to the `BufferizableOpInterface` implementation of scf loops, a 331417e1c7dSMatthias Springer /// a bufferized OpResult may alias only with the corresponding bufferized 332417e1c7dSMatthias Springer /// init_arg and with no other buffers. I.e., the i-th OpResult may alias with 333417e1c7dSMatthias Springer /// the i-th init_arg; but not with any other OpOperand. If a corresponding 334417e1c7dSMatthias Springer /// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by 335417e1c7dSMatthias Springer /// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we 336417e1c7dSMatthias Springer /// cannot be sure and must yield a new buffer copy. (New buffer copies do not 337417e1c7dSMatthias Springer /// alias with any buffer.) 338417e1c7dSMatthias Springer static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor, 339417e1c7dSMatthias Springer BaseMemRefType type, bool isEquivalent, 340417e1c7dSMatthias Springer BufferizationState &state) { 341417e1c7dSMatthias Springer assert(tensor.getType().isa<TensorType>() && "expected tensor"); 342417e1c7dSMatthias Springer ensureToMemrefOpIsValid(tensor, type); 343417e1c7dSMatthias Springer Value yieldedVal = 344417e1c7dSMatthias Springer bufferization::lookupBuffer(rewriter, tensor, state.getOptions()); 345417e1c7dSMatthias Springer 346417e1c7dSMatthias Springer if (isEquivalent) 347417e1c7dSMatthias Springer // Yielded value is equivalent to the corresponding iter_arg bbArg. 348417e1c7dSMatthias Springer // Yield the value directly. Most IR should be like that. Everything 349417e1c7dSMatthias Springer // else must be resolved with copies and is potentially inefficient. 
350417e1c7dSMatthias Springer // By default, such problematic IR would already have been rejected 351417e1c7dSMatthias Springer // during `verifyAnalysis`, unless `allow-return-allocs`. 352417e1c7dSMatthias Springer return castBuffer(rewriter, yieldedVal, type); 353417e1c7dSMatthias Springer 354417e1c7dSMatthias Springer // It is not certain that the yielded value and the iter_arg bbArg 355417e1c7dSMatthias Springer // have the same buffer. Allocate a new buffer and copy. The yielded 356417e1c7dSMatthias Springer // buffer will get deallocated by `deallocateBuffers`. 357417e1c7dSMatthias Springer 358417e1c7dSMatthias Springer // TODO: There are cases in which it is not neccessary to return a new 359417e1c7dSMatthias Springer // buffer allocation. E.g., when equivalent values are yielded in a 360417e1c7dSMatthias Springer // different order. This could be resolved with copies. 361417e1c7dSMatthias Springer Optional<Value> yieldedAlloc = state.createAlloc( 362417e1c7dSMatthias Springer rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false); 363417e1c7dSMatthias Springer // TODO: We should rollback, but for now just assume that this always 364417e1c7dSMatthias Springer // succeeds. 365417e1c7dSMatthias Springer assert(yieldedAlloc.hasValue() && "could not create alloc"); 366*248e113eSMatthias Springer LogicalResult copyStatus = state.getOptions().createMemCpy( 367*248e113eSMatthias Springer rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc); 368417e1c7dSMatthias Springer (void)copyStatus; 369417e1c7dSMatthias Springer assert(succeeded(copyStatus) && "could not create memcpy"); 370417e1c7dSMatthias Springer 371417e1c7dSMatthias Springer // The iter_arg memref type may have a layout map. Cast the new buffer 372417e1c7dSMatthias Springer // to the same type if needed. 
373417e1c7dSMatthias Springer return castBuffer(rewriter, *yieldedAlloc, type); 374417e1c7dSMatthias Springer } 375417e1c7dSMatthias Springer 376417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a range of values, apply 377417e1c7dSMatthias Springer /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified 378417e1c7dSMatthias Springer /// value in the result vector. 379417e1c7dSMatthias Springer static SmallVector<Value> 380417e1c7dSMatthias Springer convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices, 381417e1c7dSMatthias Springer llvm::function_ref<Value(Value, int64_t)> func) { 382417e1c7dSMatthias Springer SmallVector<Value> result; 383417e1c7dSMatthias Springer for (const auto &it : llvm::enumerate(values)) { 384417e1c7dSMatthias Springer size_t idx = it.index(); 385417e1c7dSMatthias Springer Value val = it.value(); 386417e1c7dSMatthias Springer result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val); 387417e1c7dSMatthias Springer } 388417e1c7dSMatthias Springer return result; 389417e1c7dSMatthias Springer } 390417e1c7dSMatthias Springer 391417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a list of pre-bufferization 392417e1c7dSMatthias Springer /// yielded values, compute the list of bufferized yielded values. 
393417e1c7dSMatthias Springer SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values, 394417e1c7dSMatthias Springer TypeRange bufferizedTypes, 395417e1c7dSMatthias Springer const DenseSet<int64_t> &tensorIndices, 396417e1c7dSMatthias Springer const DenseSet<int64_t> &equivalentTensors, 397417e1c7dSMatthias Springer BufferizationState &state) { 398417e1c7dSMatthias Springer return convertTensorValues( 399417e1c7dSMatthias Springer values, tensorIndices, [&](Value val, int64_t index) { 400417e1c7dSMatthias Springer return getYieldedBuffer(rewriter, val, 401417e1c7dSMatthias Springer bufferizedTypes[index].cast<BaseMemRefType>(), 402417e1c7dSMatthias Springer equivalentTensors.contains(index), state); 403417e1c7dSMatthias Springer }); 404417e1c7dSMatthias Springer } 405417e1c7dSMatthias Springer 406a5d09c63SMatthias Springer /// Helper function for loop bufferization. Given a list of bbArgs of the new 407a5d09c63SMatthias Springer /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into 408a5d09c63SMatthias Springer /// ToTensorOps, so that the block body can be moved over to the new op. 409a5d09c63SMatthias Springer SmallVector<Value> 410a5d09c63SMatthias Springer getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, 411a5d09c63SMatthias Springer const DenseSet<int64_t> &tensorIndices) { 412a5d09c63SMatthias Springer return convertTensorValues( 413a5d09c63SMatthias Springer bbArgs, tensorIndices, [&](Value val, int64_t index) { 414a5d09c63SMatthias Springer return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val); 415a5d09c63SMatthias Springer }); 416a5d09c63SMatthias Springer } 417a5d09c63SMatthias Springer 41819efe141SMatthias Springer /// Bufferization of scf.for. Replace with a new scf.for that operates on 41919efe141SMatthias Springer /// memrefs. 
/// Bufferization of scf.for. Replaces the op with a new scf.for that operates
/// on memrefs; tensor iter_args become memref init_args and the loop body is
/// moved over, with ToTensorOps wrapping the new (memref) bbArgs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    // Each iter_arg operand aliases with the ForOp result at the same
    // position.
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Rewrite the scf.for op into a new scf.for op whose iter_args are the
  /// bufferized (memref) init_args. Returns failure if buffers could not be
  /// obtained for all tensor init_args.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);
    // Bail out if not every tensor iter_arg produced a buffer.
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present. (If the only bbArg is the induction
    // variable, the op builder may have created an scf.yield already.)
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Non-equivalent yields are allowed when returning allocs is permitted.
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    // The i-th init_arg operand aliases with the i-th WhileOp result.
    auto whileOp = cast<scf::WhileOp>(op);
    return {whileOp->getResult(opOperand.getOperandNumber())};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // "Before" block: compare the bbArg against the scf.condition operand.
    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    // "After" block: compare the bbArg against the scf.yield operand.
    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  /// Rewrite the scf.while op into a new scf.while op whose init args are the
  /// bufferized (memref) operands. Returns failure if buffers could not be
  /// obtained for all tensor operands.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);
    // Bail out if not every tensor operand produced a buffer.
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRange(initArgs);
    TypeRange argsTypes(argsRange);
    auto newWhileOp =
        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
    // Add before/after regions to the new op.
    // NOTE(review): both regions receive bbArgs of the same (memref) types,
    // which assumes the "before" and "after" regions carry the same iter_arg
    // types — confirm for while loops whose scf.condition forwards values of
    // different types than the init args.
    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indices);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
                         equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
                         equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Non-equivalent yields are allowed when returning allocs is permitted.
    if (options.allowReturnAllocs)
      return success();

    // Check the "before" region: scf.condition operands vs. bbArgs.
    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    // Check the "after" region: scf.yield operands vs. bbArgs.
    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Yielded values are always read (they are forwarded out of the region).
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // scf.yield itself does not write to memory.
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    // For scf.if and scf.execute_region parents, the i-th yield operand
    // aliases with the i-th parent result. For other parents (e.g. scf.for,
    // scf.while), no aliasing result is reported here; see the respective
    // op interfaces above.
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  /// No rewriting needed: scf.yield is bufferized as part of its parent op.
  /// Only verify that the parent op is supported.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

/// Attach the BufferizableOpInterface external models defined above to the
/// corresponding SCF ops when the SCF dialect is loaded.
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}