193e66327SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
293e66327SMatthias Springer //
393e66327SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
493e66327SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
593e66327SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
693e66327SMatthias Springer //
793e66327SMatthias Springer //===----------------------------------------------------------------------===//
893e66327SMatthias Springer
993e66327SMatthias Springer #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
1093e66327SMatthias Springer
1193e66327SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1293e66327SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1393e66327SMatthias Springer #include "mlir/Dialect/Shape/IR/Shape.h"
1493e66327SMatthias Springer #include "mlir/IR/Dialect.h"
1593e66327SMatthias Springer #include "mlir/IR/Operation.h"
1693e66327SMatthias Springer #include "mlir/IR/PatternMatch.h"
1793e66327SMatthias Springer
1893e66327SMatthias Springer using namespace mlir;
1993e66327SMatthias Springer using namespace mlir::bufferization;
2093e66327SMatthias Springer using namespace mlir::shape;
2193e66327SMatthias Springer
2293e66327SMatthias Springer namespace mlir {
2393e66327SMatthias Springer namespace shape {
2493e66327SMatthias Springer namespace {
2593e66327SMatthias Springer
2693e66327SMatthias Springer /// Bufferization of shape.assuming.
2793e66327SMatthias Springer struct AssumingOpInterface
2893e66327SMatthias Springer : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
2993e66327SMatthias Springer shape::AssumingOp> {
3093e66327SMatthias Springer SmallVector<OpOperand *>
getAliasingOpOperandmlir::shape::__anon2c7d7a2f0111::AssumingOpInterface3193e66327SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult,
329597b16aSMatthias Springer const AnalysisState &state) const {
3393e66327SMatthias Springer // AssumingOps do not have tensor OpOperands. The yielded value can be any
3493e66327SMatthias Springer // SSA value that is in scope. To allow for use-def chain traversal through
3593e66327SMatthias Springer // AssumingOps in the analysis, the corresponding yield value is considered
3693e66327SMatthias Springer // to be aliasing with the result.
3793e66327SMatthias Springer auto assumingOp = cast<shape::AssumingOp>(op);
3893e66327SMatthias Springer size_t resultNum = std::distance(op->getOpResults().begin(),
3993e66327SMatthias Springer llvm::find(op->getOpResults(), opResult));
4093e66327SMatthias Springer // TODO: Support multiple blocks.
4193e66327SMatthias Springer assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
4293e66327SMatthias Springer "expected exactly 1 block");
4393e66327SMatthias Springer auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
4493e66327SMatthias Springer assumingOp.getDoRegion().front().getTerminator());
4593e66327SMatthias Springer assert(yieldOp && "expected shape.assuming_yield terminator");
4693e66327SMatthias Springer return {&yieldOp->getOpOperand(resultNum)};
4793e66327SMatthias Springer }
4893e66327SMatthias Springer
4993e66327SMatthias Springer // TODO: For better bufferization results, this could return `true` only if
5093e66327SMatthias Springer // there is a memory write in the region.
isMemoryWritemlir::shape::__anon2c7d7a2f0111::AssumingOpInterface5193e66327SMatthias Springer bool isMemoryWrite(Operation *op, OpResult opResult,
529597b16aSMatthias Springer const AnalysisState &state) const {
5393e66327SMatthias Springer // Similar to scf.if, results of this op are always considered memory writes
5493e66327SMatthias Springer // in the analysis. This is a useful pattern for all ops that have tensor
5593e66327SMatthias Springer // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
5693e66327SMatthias Springer // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
5793e66327SMatthias Springer // ops without OpOperands.
5893e66327SMatthias Springer return true;
5993e66327SMatthias Springer }
6093e66327SMatthias Springer
bufferizemlir::shape::__anon2c7d7a2f0111::AssumingOpInterface6193e66327SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
62b55d55ecSMatthias Springer const BufferizationOptions &options) const {
6393e66327SMatthias Springer auto assumingOp = cast<shape::AssumingOp>(op);
6419efb84cSMatthias Springer assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
6519efb84cSMatthias Springer "only 1 block supported");
6619efb84cSMatthias Springer auto yieldOp = cast<shape::AssumingYieldOp>(
6719efb84cSMatthias Springer assumingOp.getDoRegion().front().getTerminator());
6893e66327SMatthias Springer
6993e66327SMatthias Springer // Create new op and move over region.
7019efb84cSMatthias Springer TypeRange newResultTypes(yieldOp.operands());
7193e66327SMatthias Springer auto newOp = rewriter.create<shape::AssumingOp>(
7293e66327SMatthias Springer op->getLoc(), newResultTypes, assumingOp.getWitness());
7393e66327SMatthias Springer newOp.getDoRegion().takeBody(assumingOp.getRegion());
7493e66327SMatthias Springer
7593e66327SMatthias Springer // Update all uses of the old op.
7693e66327SMatthias Springer rewriter.setInsertionPointAfter(newOp);
7793e66327SMatthias Springer SmallVector<Value> newResults;
7893e66327SMatthias Springer for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
7993e66327SMatthias Springer if (it.value().isa<TensorType>()) {
8093e66327SMatthias Springer newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
8193e66327SMatthias Springer assumingOp.getLoc(), newOp->getResult(it.index())));
8293e66327SMatthias Springer } else {
8393e66327SMatthias Springer newResults.push_back(newOp->getResult(it.index()));
8493e66327SMatthias Springer }
8593e66327SMatthias Springer }
8693e66327SMatthias Springer
8793e66327SMatthias Springer // Replace old op.
8893e66327SMatthias Springer rewriter.replaceOp(assumingOp, newResults);
8993e66327SMatthias Springer
9093e66327SMatthias Springer return success();
9193e66327SMatthias Springer }
9293e66327SMatthias Springer
bufferRelationmlir::shape::__anon2c7d7a2f0111::AssumingOpInterface9393e66327SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult,
949597b16aSMatthias Springer const AnalysisState &state) const {
9593e66327SMatthias Springer return BufferRelation::Equivalent;
9693e66327SMatthias Springer }
9793e66327SMatthias Springer };
9893e66327SMatthias Springer
9993e66327SMatthias Springer /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
10093e66327SMatthias Springer /// ops, so this is for analysis only.
10193e66327SMatthias Springer struct AssumingYieldOpInterface
10293e66327SMatthias Springer : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
1036ab1ed43SMatthias Springer shape::AssumingYieldOp> {
bufferizesToMemoryReadmlir::shape::__anon2c7d7a2f0111::AssumingYieldOpInterface10493e66327SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1059597b16aSMatthias Springer const AnalysisState &state) const {
10693e66327SMatthias Springer return true;
10793e66327SMatthias Springer }
10893e66327SMatthias Springer
bufferizesToMemoryWritemlir::shape::__anon2c7d7a2f0111::AssumingYieldOpInterface10993e66327SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1109597b16aSMatthias Springer const AnalysisState &state) const {
11193e66327SMatthias Springer return false;
11293e66327SMatthias Springer }
11393e66327SMatthias Springer
getAliasingOpResultmlir::shape::__anon2c7d7a2f0111::AssumingYieldOpInterface1149597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
1159597b16aSMatthias Springer const AnalysisState &state) const {
11693e66327SMatthias Springer assert(isa<shape::AssumingOp>(op->getParentOp()) &&
11793e66327SMatthias Springer "expected that parent is an AssumingOp");
11893e66327SMatthias Springer return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
11993e66327SMatthias Springer }
12093e66327SMatthias Springer
mustBufferizeInPlacemlir::shape::__anon2c7d7a2f0111::AssumingYieldOpInterface12193e66327SMatthias Springer bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
1229597b16aSMatthias Springer const AnalysisState &state) const {
12393e66327SMatthias Springer // Yield operands always bufferize inplace. Otherwise, an alloc + copy
12493e66327SMatthias Springer // may be generated inside the block. We should not return/yield allocations
12593e66327SMatthias Springer // when possible.
12693e66327SMatthias Springer return true;
12793e66327SMatthias Springer }
12893e66327SMatthias Springer
bufferizemlir::shape::__anon2c7d7a2f0111::AssumingYieldOpInterface12993e66327SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
130b55d55ecSMatthias Springer const BufferizationOptions &options) const {
13119efb84cSMatthias Springer auto yieldOp = cast<shape::AssumingYieldOp>(op);
13219efb84cSMatthias Springer SmallVector<Value> newResults;
133*5d50f51cSMatthias Springer for (Value value : yieldOp.operands()) {
134*5d50f51cSMatthias Springer if (value.getType().isa<TensorType>()) {
135*5d50f51cSMatthias Springer FailureOr<Value> buffer = getBuffer(rewriter, value, options);
136*5d50f51cSMatthias Springer if (failed(buffer))
137*5d50f51cSMatthias Springer return failure();
138*5d50f51cSMatthias Springer newResults.push_back(*buffer);
139*5d50f51cSMatthias Springer } else {
140*5d50f51cSMatthias Springer newResults.push_back(value);
141*5d50f51cSMatthias Springer }
142*5d50f51cSMatthias Springer }
14319efb84cSMatthias Springer replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
14419efb84cSMatthias Springer newResults);
145ba9d886dSMatthias Springer return success();
14693e66327SMatthias Springer }
14793e66327SMatthias Springer };
14893e66327SMatthias Springer
14993e66327SMatthias Springer } // namespace
15093e66327SMatthias Springer } // namespace shape
15193e66327SMatthias Springer } // namespace mlir
15293e66327SMatthias Springer
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)15393e66327SMatthias Springer void mlir::shape::registerBufferizableOpInterfaceExternalModels(
15493e66327SMatthias Springer DialectRegistry ®istry) {
15577eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
15677eee579SRiver Riddle shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
15777eee579SRiver Riddle shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
15877eee579SRiver Riddle });
15993e66327SMatthias Springer }
160