119efe141SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 219efe141SMatthias Springer // 319efe141SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 419efe141SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 519efe141SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 619efe141SMatthias Springer // 719efe141SMatthias Springer //===----------------------------------------------------------------------===// 819efe141SMatthias Springer 919efe141SMatthias Springer #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" 1019efe141SMatthias Springer 1119efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1219efe141SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 131e1eeae8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 14eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 1519efe141SMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 1619efe141SMatthias Springer #include "mlir/IR/Dialect.h" 1719efe141SMatthias Springer #include "mlir/IR/Operation.h" 1819efe141SMatthias Springer #include "mlir/IR/PatternMatch.h" 1919efe141SMatthias Springer 2019efe141SMatthias Springer using namespace mlir; 2119efe141SMatthias Springer using namespace mlir::bufferization; 2219efe141SMatthias Springer using namespace mlir::scf; 2319efe141SMatthias Springer 2419efe141SMatthias Springer namespace mlir { 2519efe141SMatthias Springer namespace scf { 2619efe141SMatthias Springer namespace { 2719efe141SMatthias Springer 2819efe141SMatthias Springer // bufferization.to_memref is not allowed to change the rank. 
2919efe141SMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 3019efe141SMatthias Springer #ifndef NDEBUG 3119efe141SMatthias Springer auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>(); 3219efe141SMatthias Springer assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() == 3319efe141SMatthias Springer rankedTensorType.getRank())) && 3419efe141SMatthias Springer "to_memref would be invalid: mismatching ranks"); 3519efe141SMatthias Springer #endif 3619efe141SMatthias Springer } 3719efe141SMatthias Springer 3819efe141SMatthias Springer /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not 3919efe141SMatthias Springer /// fully implemented at the moment. 4019efe141SMatthias Springer struct ExecuteRegionOpInterface 4119efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface, 4219efe141SMatthias Springer scf::ExecuteRegionOp> { 4319efe141SMatthias Springer SmallVector<OpOperand *> 4419efe141SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 459597b16aSMatthias Springer const AnalysisState &state) const { 4619efe141SMatthias Springer // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be 4719efe141SMatthias Springer // any SSA value that is in scope. To allow for use-def chain traversal 4819efe141SMatthias Springer // through ExecuteRegionOps in the analysis, the corresponding yield value 4919efe141SMatthias Springer // is considered to be aliasing with the result. 5019efe141SMatthias Springer auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 5119efe141SMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(), 5219efe141SMatthias Springer llvm::find(op->getOpResults(), opResult)); 5319efe141SMatthias Springer // TODO: Support multiple blocks. 
5419efe141SMatthias Springer assert(executeRegionOp.getRegion().getBlocks().size() == 1 && 5519efe141SMatthias Springer "expected exactly 1 block"); 5619efe141SMatthias Springer auto yieldOp = dyn_cast<scf::YieldOp>( 5719efe141SMatthias Springer executeRegionOp.getRegion().front().getTerminator()); 5819efe141SMatthias Springer assert(yieldOp && "expected scf.yield terminator in scf.execute_region"); 5919efe141SMatthias Springer return {&yieldOp->getOpOperand(resultNum)}; 6019efe141SMatthias Springer } 6119efe141SMatthias Springer 6219efe141SMatthias Springer // TODO: For better bufferization results, this could return `true` only if 6319efe141SMatthias Springer // there is a memory write in the region. 6419efe141SMatthias Springer bool isMemoryWrite(Operation *op, OpResult opResult, 659597b16aSMatthias Springer const AnalysisState &state) const { 6619efe141SMatthias Springer // Similar to scf.if, results of this op are always considered memory writes 6719efe141SMatthias Springer // in the analysis. This is a useful pattern for all ops that have tensor 6819efe141SMatthias Springer // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is 6919efe141SMatthias Springer // implemented in terms of `bufferizesToMemoryWrite`, which does not work on 7019efe141SMatthias Springer // ops without OpOperands. 7119efe141SMatthias Springer return true; 7219efe141SMatthias Springer } 7319efe141SMatthias Springer 7419efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 759597b16aSMatthias Springer BufferizationState &state) const { 7619efe141SMatthias Springer auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 7719efe141SMatthias Springer 7819efe141SMatthias Springer // Compute new result types. 
7919efe141SMatthias Springer SmallVector<Type> newResultTypes; 8019efe141SMatthias Springer for (Type type : executeRegionOp->getResultTypes()) { 8119efe141SMatthias Springer if (auto tensorType = type.dyn_cast<TensorType>()) { 82*12e41d92SMatthias Springer // TODO: Infer the result type instead of computing it. 8319efe141SMatthias Springer newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); 8419efe141SMatthias Springer } else { 8519efe141SMatthias Springer newResultTypes.push_back(type); 8619efe141SMatthias Springer } 8719efe141SMatthias Springer } 8819efe141SMatthias Springer 8919efe141SMatthias Springer // Create new op and move over region. 9019efe141SMatthias Springer auto newOp = 9119efe141SMatthias Springer rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes); 9219efe141SMatthias Springer newOp.getRegion().takeBody(executeRegionOp.getRegion()); 9319efe141SMatthias Springer 9419efe141SMatthias Springer // Update terminator. 9519efe141SMatthias Springer assert(newOp.getRegion().getBlocks().size() == 1 && 9619efe141SMatthias Springer "only 1 block supported"); 9719efe141SMatthias Springer Block *newBlock = &newOp.getRegion().front(); 9819efe141SMatthias Springer auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator()); 9919efe141SMatthias Springer rewriter.setInsertionPoint(yieldOp); 10019efe141SMatthias Springer SmallVector<Value> newYieldValues; 101bb6119ebSMehdi Amini for (const auto &it : llvm::enumerate(yieldOp.getResults())) { 10219efe141SMatthias Springer Value val = it.value(); 10319efe141SMatthias Springer if (val.getType().isa<TensorType>()) { 10419efe141SMatthias Springer newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>( 10519efe141SMatthias Springer yieldOp.getLoc(), newResultTypes[it.index()], val)); 10619efe141SMatthias Springer } else { 10719efe141SMatthias Springer newYieldValues.push_back(val); 10819efe141SMatthias Springer } 10919efe141SMatthias Springer } 11019efe141SMatthias Springer 
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 11119efe141SMatthias Springer 11219efe141SMatthias Springer // Update all uses of the old op. 11319efe141SMatthias Springer rewriter.setInsertionPointAfter(newOp); 11419efe141SMatthias Springer SmallVector<Value> newResults; 115bb6119ebSMehdi Amini for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { 11619efe141SMatthias Springer if (it.value().isa<TensorType>()) { 11719efe141SMatthias Springer newResults.push_back(rewriter.create<bufferization::ToTensorOp>( 11819efe141SMatthias Springer executeRegionOp.getLoc(), newOp->getResult(it.index()))); 11919efe141SMatthias Springer } else { 12019efe141SMatthias Springer newResults.push_back(newOp->getResult(it.index())); 12119efe141SMatthias Springer } 12219efe141SMatthias Springer } 12319efe141SMatthias Springer 12419efe141SMatthias Springer // Replace old op. 12519efe141SMatthias Springer rewriter.replaceOp(executeRegionOp, newResults); 12619efe141SMatthias Springer 12719efe141SMatthias Springer return success(); 12819efe141SMatthias Springer } 12919efe141SMatthias Springer 13019efe141SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 1319597b16aSMatthias Springer const AnalysisState &state) const { 13219efe141SMatthias Springer return BufferRelation::Equivalent; 13319efe141SMatthias Springer } 13419efe141SMatthias Springer }; 13519efe141SMatthias Springer 13619efe141SMatthias Springer /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. 13719efe141SMatthias Springer struct IfOpInterface 13819efe141SMatthias Springer : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> { 13919efe141SMatthias Springer SmallVector<OpOperand *> 14019efe141SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 1419597b16aSMatthias Springer const AnalysisState &state) const { 14219efe141SMatthias Springer // IfOps do not have tensor OpOperands. 
The yielded value can be any SSA 14319efe141SMatthias Springer // value that is in scope. To allow for use-def chain traversal through 14419efe141SMatthias Springer // IfOps in the analysis, both corresponding yield values from the then/else 14519efe141SMatthias Springer // branches are considered to be aliasing with the result. 14619efe141SMatthias Springer auto ifOp = cast<scf::IfOp>(op); 14719efe141SMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(), 14819efe141SMatthias Springer llvm::find(op->getOpResults(), opResult)); 14919efe141SMatthias Springer return {&ifOp.thenYield()->getOpOperand(resultNum), 15019efe141SMatthias Springer &ifOp.elseYield()->getOpOperand(resultNum)}; 15119efe141SMatthias Springer } 15219efe141SMatthias Springer 15319efe141SMatthias Springer // TODO: For better bufferization results, this could return `true` only if 15419efe141SMatthias Springer // there is a memory write in one (or both) of the branches. Since this is not 15519efe141SMatthias Springer // allowed at the moment, we should never encounter scf.ifs that yield 15619efe141SMatthias Springer // unmodified tensors. Such scf.yield ops could just fold away. 15719efe141SMatthias Springer bool isMemoryWrite(Operation *op, OpResult opResult, 1589597b16aSMatthias Springer const AnalysisState &state) const { 15919efe141SMatthias Springer // IfOp results are always considered memory writes in the analysis. This 16019efe141SMatthias Springer // design decision simplifies the analysis considerably. 
E.g., consider the 16119efe141SMatthias Springer // following test case: 16219efe141SMatthias Springer // 16319efe141SMatthias Springer // %0 = "some_writing_op" : tensor<?xf32> 16419efe141SMatthias Springer // %r = scf.if %c -> (tensor<?xf32>) { 16519efe141SMatthias Springer // scf.yield %0 16619efe141SMatthias Springer // } else { 16719efe141SMatthias Springer // %1 = "another_writing_op"(%0) : tensor<?xf32> 16819efe141SMatthias Springer // } 16919efe141SMatthias Springer // "some_reading_op"(%r) 17019efe141SMatthias Springer // 17119efe141SMatthias Springer // "another_writing_op" in the above example should be able to bufferize 17219efe141SMatthias Springer // inplace in the absence of another read of %0. However, if the scf.if op 17319efe141SMatthias Springer // would not be considered a "write", the analysis would detect the 17419efe141SMatthias Springer // following conflict: 17519efe141SMatthias Springer // 17619efe141SMatthias Springer // * read = some_reading_op 17719efe141SMatthias Springer // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) 17819efe141SMatthias Springer // * conflictingWrite = %1 17919efe141SMatthias Springer // 18019efe141SMatthias Springer // For more details, check the "scf.IfOp" section of the design document. 18119efe141SMatthias Springer return true; 18219efe141SMatthias Springer } 18319efe141SMatthias Springer 18419efe141SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1859597b16aSMatthias Springer BufferizationState &state) const { 18619efe141SMatthias Springer auto ifOp = cast<scf::IfOp>(op); 18719efe141SMatthias Springer 18819efe141SMatthias Springer // Compute new types of the bufferized scf.if op. 
18919efe141SMatthias Springer SmallVector<Type> newTypes; 19019efe141SMatthias Springer for (Type returnType : ifOp->getResultTypes()) { 19119efe141SMatthias Springer if (auto tensorType = returnType.dyn_cast<TensorType>()) { 192*12e41d92SMatthias Springer // TODO: Infer the result type instead of computing it. 19319efe141SMatthias Springer newTypes.push_back(getMemRefType(tensorType, state.getOptions())); 19419efe141SMatthias Springer } else { 19519efe141SMatthias Springer newTypes.push_back(returnType); 19619efe141SMatthias Springer } 19719efe141SMatthias Springer } 19819efe141SMatthias Springer 19919efe141SMatthias Springer // Create new op. 20019efe141SMatthias Springer auto newIfOp = 20119efe141SMatthias Springer rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(), 20219efe141SMatthias Springer /*withElseRegion=*/true); 20319efe141SMatthias Springer 20419efe141SMatthias Springer // Remove terminators. 20519efe141SMatthias Springer if (!newIfOp.thenBlock()->empty()) { 20619efe141SMatthias Springer rewriter.eraseOp(newIfOp.thenBlock()->getTerminator()); 20719efe141SMatthias Springer rewriter.eraseOp(newIfOp.elseBlock()->getTerminator()); 20819efe141SMatthias Springer } 20919efe141SMatthias Springer 21019efe141SMatthias Springer // Move over then/else blocks. 21119efe141SMatthias Springer rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); 21219efe141SMatthias Springer rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); 21319efe141SMatthias Springer 21419efe141SMatthias Springer // Update scf.yield of new then-block. 
21519efe141SMatthias Springer auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator()); 21619efe141SMatthias Springer rewriter.setInsertionPoint(thenYieldOp); 21719efe141SMatthias Springer SmallVector<Value> thenYieldValues; 21819efe141SMatthias Springer for (OpOperand &operand : thenYieldOp->getOpOperands()) { 21919efe141SMatthias Springer if (operand.get().getType().isa<TensorType>()) { 22019efe141SMatthias Springer ensureToMemrefOpIsValid(operand.get(), 22119efe141SMatthias Springer newTypes[operand.getOperandNumber()]); 22219efe141SMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 22319efe141SMatthias Springer operand.get().getLoc(), newTypes[operand.getOperandNumber()], 22419efe141SMatthias Springer operand.get()); 22519efe141SMatthias Springer operand.set(toMemrefOp); 22619efe141SMatthias Springer } 22719efe141SMatthias Springer } 22819efe141SMatthias Springer 22919efe141SMatthias Springer // Update scf.yield of new else-block. 23019efe141SMatthias Springer auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator()); 23119efe141SMatthias Springer rewriter.setInsertionPoint(elseYieldOp); 23219efe141SMatthias Springer SmallVector<Value> elseYieldValues; 23319efe141SMatthias Springer for (OpOperand &operand : elseYieldOp->getOpOperands()) { 23419efe141SMatthias Springer if (operand.get().getType().isa<TensorType>()) { 23519efe141SMatthias Springer ensureToMemrefOpIsValid(operand.get(), 23619efe141SMatthias Springer newTypes[operand.getOperandNumber()]); 23719efe141SMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 23819efe141SMatthias Springer operand.get().getLoc(), newTypes[operand.getOperandNumber()], 23919efe141SMatthias Springer operand.get()); 24019efe141SMatthias Springer operand.set(toMemrefOp); 24119efe141SMatthias Springer } 24219efe141SMatthias Springer } 24319efe141SMatthias Springer 24419efe141SMatthias Springer // Replace op results. 
24519efe141SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); 24619efe141SMatthias Springer 24719efe141SMatthias Springer return success(); 24819efe141SMatthias Springer } 24919efe141SMatthias Springer 25019efe141SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 2519597b16aSMatthias Springer const AnalysisState &state) const { 25219efe141SMatthias Springer // IfOp results are equivalent to their corresponding yield values if both 25319efe141SMatthias Springer // yield values are equivalent to each other. 25419efe141SMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(op); 25519efe141SMatthias Springer SmallVector<OpOperand *> yieldValues = 25619efe141SMatthias Springer bufferizableOp.getAliasingOpOperand(opResult, state); 25719efe141SMatthias Springer assert(yieldValues.size() == 2 && "expected 2 yield values"); 25819efe141SMatthias Springer bool equivalentYields = state.areEquivalentBufferizedValues( 25919efe141SMatthias Springer yieldValues[0]->get(), yieldValues[1]->get()); 26019efe141SMatthias Springer return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None; 26119efe141SMatthias Springer } 26219efe141SMatthias Springer }; 26319efe141SMatthias Springer 264417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the indices of all values 265417e1c7dSMatthias Springer /// that have a tensor type. 266417e1c7dSMatthias Springer static DenseSet<int64_t> getTensorIndices(ValueRange values) { 267417e1c7dSMatthias Springer DenseSet<int64_t> result; 268417e1c7dSMatthias Springer for (const auto &it : llvm::enumerate(values)) 269417e1c7dSMatthias Springer if (it.value().getType().isa<TensorType>()) 270417e1c7dSMatthias Springer result.insert(it.index()); 271417e1c7dSMatthias Springer return result; 272417e1c7dSMatthias Springer } 273417e1c7dSMatthias Springer 274417e1c7dSMatthias Springer /// Helper function for loop bufferization. 
Return the indices of all 275417e1c7dSMatthias Springer /// bbArg/yielded value pairs who's buffer relation is "Equivalent". 276a5d09c63SMatthias Springer DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs, 277417e1c7dSMatthias Springer ValueRange yieldedValues, 278417e1c7dSMatthias Springer const AnalysisState &state) { 279417e1c7dSMatthias Springer DenseSet<int64_t> result; 280417e1c7dSMatthias Springer int64_t counter = 0; 281417e1c7dSMatthias Springer for (const auto &it : llvm::zip(bbArgs, yieldedValues)) { 282417e1c7dSMatthias Springer if (!std::get<0>(it).getType().isa<TensorType>()) 283417e1c7dSMatthias Springer continue; 284417e1c7dSMatthias Springer if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it))) 285417e1c7dSMatthias Springer result.insert(counter); 286417e1c7dSMatthias Springer counter++; 287417e1c7dSMatthias Springer } 288417e1c7dSMatthias Springer return result; 289417e1c7dSMatthias Springer } 290417e1c7dSMatthias Springer 291417e1c7dSMatthias Springer /// Helper function for loop bufferization. Cast the given buffer to the given 292417e1c7dSMatthias Springer /// memref type. 293417e1c7dSMatthias Springer static Value castBuffer(OpBuilder &b, Value buffer, Type type) { 294417e1c7dSMatthias Springer assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType"); 295417e1c7dSMatthias Springer assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType"); 296417e1c7dSMatthias Springer // If the buffer already has the correct type, no cast is needed. 297417e1c7dSMatthias Springer if (buffer.getType() == type) 298417e1c7dSMatthias Springer return buffer; 299417e1c7dSMatthias Springer // TODO: In case `type` has a layout map that is not the fully dynamic 300417e1c7dSMatthias Springer // one, we may not be able to cast the buffer. In that case, the loop 301417e1c7dSMatthias Springer // iter_arg's layout map must be changed (see uses of `castBuffer`). 
302417e1c7dSMatthias Springer assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && 303417e1c7dSMatthias Springer "scf.while op bufferization: cast incompatible"); 304417e1c7dSMatthias Springer return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult(); 305417e1c7dSMatthias Springer } 306417e1c7dSMatthias Springer 307417e1c7dSMatthias Springer /// Helper function for loop bufferization. Return the bufferized values of the 308417e1c7dSMatthias Springer /// given OpOperands. If an operand is not a tensor, return the original value. 309417e1c7dSMatthias Springer static SmallVector<Value> getBuffers(RewriterBase &rewriter, 310417e1c7dSMatthias Springer MutableArrayRef<OpOperand> operands, 311417e1c7dSMatthias Springer BufferizationState &state) { 312417e1c7dSMatthias Springer SmallVector<Value> result; 313417e1c7dSMatthias Springer for (OpOperand &opOperand : operands) { 314417e1c7dSMatthias Springer if (opOperand.get().getType().isa<TensorType>()) { 315417e1c7dSMatthias Springer FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand); 316417e1c7dSMatthias Springer if (failed(resultBuffer)) 317417e1c7dSMatthias Springer return {}; 318417e1c7dSMatthias Springer result.push_back(*resultBuffer); 319417e1c7dSMatthias Springer } else { 320417e1c7dSMatthias Springer result.push_back(opOperand.get()); 321417e1c7dSMatthias Springer } 322417e1c7dSMatthias Springer } 323417e1c7dSMatthias Springer return result; 324417e1c7dSMatthias Springer } 325417e1c7dSMatthias Springer 326417e1c7dSMatthias Springer /// Helper function for loop bufferization. Compute the buffer that should be 327417e1c7dSMatthias Springer /// yielded from a loop block (loop body or loop condition). If the given tensor 328417e1c7dSMatthias Springer /// is equivalent to the corresponding block argument (as indicated by 329417e1c7dSMatthias Springer /// `isEquivalent`), the buffer can be yielded directly. 
Otherwise, a new buffer 330417e1c7dSMatthias Springer /// copy must be yielded. 331417e1c7dSMatthias Springer /// 332417e1c7dSMatthias Springer /// According to the `BufferizableOpInterface` implementation of scf loops, a 333417e1c7dSMatthias Springer /// a bufferized OpResult may alias only with the corresponding bufferized 334417e1c7dSMatthias Springer /// init_arg and with no other buffers. I.e., the i-th OpResult may alias with 335417e1c7dSMatthias Springer /// the i-th init_arg; but not with any other OpOperand. If a corresponding 336417e1c7dSMatthias Springer /// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by 337417e1c7dSMatthias Springer /// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we 338417e1c7dSMatthias Springer /// cannot be sure and must yield a new buffer copy. (New buffer copies do not 339417e1c7dSMatthias Springer /// alias with any buffer.) 340417e1c7dSMatthias Springer static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor, 341417e1c7dSMatthias Springer BaseMemRefType type, bool isEquivalent, 342417e1c7dSMatthias Springer BufferizationState &state) { 343417e1c7dSMatthias Springer assert(tensor.getType().isa<TensorType>() && "expected tensor"); 344417e1c7dSMatthias Springer ensureToMemrefOpIsValid(tensor, type); 345417e1c7dSMatthias Springer Value yieldedVal = 346417e1c7dSMatthias Springer bufferization::lookupBuffer(rewriter, tensor, state.getOptions()); 347417e1c7dSMatthias Springer 348417e1c7dSMatthias Springer if (isEquivalent) 349417e1c7dSMatthias Springer // Yielded value is equivalent to the corresponding iter_arg bbArg. 350417e1c7dSMatthias Springer // Yield the value directly. Most IR should be like that. Everything 351417e1c7dSMatthias Springer // else must be resolved with copies and is potentially inefficient. 
352417e1c7dSMatthias Springer // By default, such problematic IR would already have been rejected 353417e1c7dSMatthias Springer // during `verifyAnalysis`, unless `allow-return-allocs`. 354417e1c7dSMatthias Springer return castBuffer(rewriter, yieldedVal, type); 355417e1c7dSMatthias Springer 356417e1c7dSMatthias Springer // It is not certain that the yielded value and the iter_arg bbArg 357417e1c7dSMatthias Springer // have the same buffer. Allocate a new buffer and copy. The yielded 358417e1c7dSMatthias Springer // buffer will get deallocated by `deallocateBuffers`. 359417e1c7dSMatthias Springer 360417e1c7dSMatthias Springer // TODO: There are cases in which it is not neccessary to return a new 361417e1c7dSMatthias Springer // buffer allocation. E.g., when equivalent values are yielded in a 362417e1c7dSMatthias Springer // different order. This could be resolved with copies. 363417e1c7dSMatthias Springer Optional<Value> yieldedAlloc = state.createAlloc( 364417e1c7dSMatthias Springer rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false); 365417e1c7dSMatthias Springer // TODO: We should rollback, but for now just assume that this always 366417e1c7dSMatthias Springer // succeeds. 367417e1c7dSMatthias Springer assert(yieldedAlloc.hasValue() && "could not create alloc"); 368248e113eSMatthias Springer LogicalResult copyStatus = state.getOptions().createMemCpy( 369248e113eSMatthias Springer rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc); 370417e1c7dSMatthias Springer (void)copyStatus; 371417e1c7dSMatthias Springer assert(succeeded(copyStatus) && "could not create memcpy"); 372417e1c7dSMatthias Springer 373417e1c7dSMatthias Springer // The iter_arg memref type may have a layout map. Cast the new buffer 374417e1c7dSMatthias Springer // to the same type if needed. 
375417e1c7dSMatthias Springer return castBuffer(rewriter, *yieldedAlloc, type); 376417e1c7dSMatthias Springer } 377417e1c7dSMatthias Springer 378417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a range of values, apply 379417e1c7dSMatthias Springer /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified 380417e1c7dSMatthias Springer /// value in the result vector. 381417e1c7dSMatthias Springer static SmallVector<Value> 382417e1c7dSMatthias Springer convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices, 383417e1c7dSMatthias Springer llvm::function_ref<Value(Value, int64_t)> func) { 384417e1c7dSMatthias Springer SmallVector<Value> result; 385417e1c7dSMatthias Springer for (const auto &it : llvm::enumerate(values)) { 386417e1c7dSMatthias Springer size_t idx = it.index(); 387417e1c7dSMatthias Springer Value val = it.value(); 388417e1c7dSMatthias Springer result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val); 389417e1c7dSMatthias Springer } 390417e1c7dSMatthias Springer return result; 391417e1c7dSMatthias Springer } 392417e1c7dSMatthias Springer 393417e1c7dSMatthias Springer /// Helper function for loop bufferization. Given a list of pre-bufferization 394417e1c7dSMatthias Springer /// yielded values, compute the list of bufferized yielded values. 
395417e1c7dSMatthias Springer SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values, 396417e1c7dSMatthias Springer TypeRange bufferizedTypes, 397417e1c7dSMatthias Springer const DenseSet<int64_t> &tensorIndices, 398417e1c7dSMatthias Springer const DenseSet<int64_t> &equivalentTensors, 399417e1c7dSMatthias Springer BufferizationState &state) { 400417e1c7dSMatthias Springer return convertTensorValues( 401417e1c7dSMatthias Springer values, tensorIndices, [&](Value val, int64_t index) { 402417e1c7dSMatthias Springer return getYieldedBuffer(rewriter, val, 403417e1c7dSMatthias Springer bufferizedTypes[index].cast<BaseMemRefType>(), 404417e1c7dSMatthias Springer equivalentTensors.contains(index), state); 405417e1c7dSMatthias Springer }); 406417e1c7dSMatthias Springer } 407417e1c7dSMatthias Springer 408a5d09c63SMatthias Springer /// Helper function for loop bufferization. Given a list of bbArgs of the new 409a5d09c63SMatthias Springer /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into 410a5d09c63SMatthias Springer /// ToTensorOps, so that the block body can be moved over to the new op. 411a5d09c63SMatthias Springer SmallVector<Value> 412a5d09c63SMatthias Springer getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, 413a5d09c63SMatthias Springer const DenseSet<int64_t> &tensorIndices) { 414a5d09c63SMatthias Springer return convertTensorValues( 415a5d09c63SMatthias Springer bbArgs, tensorIndices, [&](Value val, int64_t index) { 416a5d09c63SMatthias Springer return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val); 417a5d09c63SMatthias Springer }); 418a5d09c63SMatthias Springer } 419a5d09c63SMatthias Springer 42019efe141SMatthias Springer /// Bufferization of scf.for. Replace with a new scf.for that operates on 42119efe141SMatthias Springer /// memrefs. 
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  /// An init_arg is considered "read" only if the matching region iter_arg is
  /// actually read by some op inside the loop body.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  /// Every tensor init_arg is conservatively treated as written to.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  /// Each init_arg aliases with exactly one OpResult: the result at the same
  /// position.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  /// Return "Equivalent" for an OpResult iff the corresponding region
  /// iter_arg and the value yielded for that position bufferize to
  /// equivalent buffers.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  /// Region iter_args of scf.for loops are always writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  /// Rewrite the scf.for op into a new scf.for op whose iter_args are
  /// memrefs. Tensor bbArg uses inside the loop body are bridged with
  /// to_tensor ops; yielded tensors are bridged back to memrefs.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);
    // getBuffers returns one buffer per tensor operand; a size mismatch means
    // buffer retrieval failed for at least one operand.
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Nothing to verify when returning/yielding new allocations is allowed.
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  /// Every tensor init_arg of an scf.while is conservatively treated as read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  /// Every tensor init_arg of an scf.while is conservatively treated as
  /// written to.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  /// Each init_arg aliases with the OpResult at the same position.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    return {whileOp->getResult(opOperand.getOperandNumber())};
  }

  /// Return "Equivalent" for an OpResult iff equivalence holds through both
  /// regions: the "before" bbArg vs. the scf.condition operand, and the
  /// "after" bbArg vs. the scf.yield operand.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  /// Region iter_args of scf.while loops are always writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  /// Rewrite the scf.while op into a new scf.while op whose init_args are
  /// memrefs. Both the "before" and "after" regions are moved over; tensor
  /// bbArg uses are bridged with to_tensor ops, and the scf.condition /
  /// scf.yield operands are bridged back to memrefs.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);
    // getBuffers returns one buffer per tensor operand; a size mismatch means
    // buffer retrieval failed for at least one operand.
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRange(initArgs);
    TypeRange argsTypes(argsRange);
    auto newWhileOp =
        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indices);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
                         equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
                         equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Nothing to verify when returning/yielding new allocations is allowed.
    if (options.allowReturnAllocs)
      return success();

    // Check the "before" region: each scf.condition arg must be equivalent to
    // the "before" bbArg at the same position.
    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    // Check the "after" region: each scf.yield operand must be equivalent to
    // the "after" bbArg at the same position.
    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  /// Yield operands are always considered read (their value is forwarded).
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// scf.yield itself never writes to memory.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// For scf.if and scf.execute_region parents, a yield operand aliases with
  /// the parent op's result at the same position. (For scf.for/scf.while,
  /// aliasing is modeled on the parent op's interface instead.)
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  /// Yield operands must bufferize in place.
  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  /// No rewrite needed: the parent op's bufferize() handles the terminator.
  /// Only validate that the parent op is a supported SCF construct.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

/// Attach the external bufferization models defined above to the SCF dialect
/// ops via a dialect extension.
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}