1ae1ea0beSJulian Gross //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2ae1ea0beSJulian Gross //
3ae1ea0beSJulian Gross // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ae1ea0beSJulian Gross // See https://llvm.org/LICENSE.txt for license information.
5ae1ea0beSJulian Gross // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ae1ea0beSJulian Gross //
7ae1ea0beSJulian Gross //===----------------------------------------------------------------------===//
8ae1ea0beSJulian Gross //
9ae1ea0beSJulian Gross // This file implements patterns to convert Bufferization dialect to MemRef
10ae1ea0beSJulian Gross // dialect.
11ae1ea0beSJulian Gross //
12ae1ea0beSJulian Gross //===----------------------------------------------------------------------===//
13ae1ea0beSJulian Gross 
14ae1ea0beSJulian Gross #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
15*b7f93c28SJeff Niu #include "../PassDetail.h"
16ae1ea0beSJulian Gross #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17ae1ea0beSJulian Gross #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18ae1ea0beSJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
19ae1ea0beSJulian Gross #include "mlir/IR/BuiltinTypes.h"
20ae1ea0beSJulian Gross #include "mlir/Support/LogicalResult.h"
21ae1ea0beSJulian Gross #include "mlir/Transforms/DialectConversion.h"
22ae1ea0beSJulian Gross 
23ae1ea0beSJulian Gross using namespace mlir;
24ae1ea0beSJulian Gross 
25ae1ea0beSJulian Gross namespace {
26ae1ea0beSJulian Gross /// The CloneOpConversion transforms all bufferization clone operations into
27ae1ea0beSJulian Gross /// memref alloc and memref copy operations. In the dynamic-shape case, it also
28ae1ea0beSJulian Gross /// emits additional dim and constant operations to determine the shape. This
29ae1ea0beSJulian Gross /// conversion does not resolve memory leaks if it is used alone.
30ae1ea0beSJulian Gross struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
31ae1ea0beSJulian Gross   using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
32ae1ea0beSJulian Gross 
33ae1ea0beSJulian Gross   LogicalResult
matchAndRewrite__anonf46417ab0111::CloneOpConversion34ae1ea0beSJulian Gross   matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
35ae1ea0beSJulian Gross                   ConversionPatternRewriter &rewriter) const override {
36ae1ea0beSJulian Gross     // Check for unranked memref types which are currently not supported.
37ae1ea0beSJulian Gross     Type type = op.getType();
38ae1ea0beSJulian Gross     if (type.isa<UnrankedMemRefType>()) {
39ae1ea0beSJulian Gross       return rewriter.notifyMatchFailure(
40ae1ea0beSJulian Gross           op, "UnrankedMemRefType is not supported.");
41ae1ea0beSJulian Gross     }
428dca38d5SMatthias Springer     MemRefType memrefType = type.cast<MemRefType>();
438dca38d5SMatthias Springer     MemRefLayoutAttrInterface layout;
448dca38d5SMatthias Springer     auto allocType =
458dca38d5SMatthias Springer         MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
468dca38d5SMatthias Springer                         layout, memrefType.getMemorySpace());
478dca38d5SMatthias Springer     // Since this implementation always allocates, certain result types of the
488dca38d5SMatthias Springer     // clone op cannot be lowered.
498dca38d5SMatthias Springer     if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
508dca38d5SMatthias Springer       return failure();
51ae1ea0beSJulian Gross 
52ae1ea0beSJulian Gross     // Transform a clone operation into alloc + copy operation and pay
53ae1ea0beSJulian Gross     // attention to the shape dimensions.
54ae1ea0beSJulian Gross     Location loc = op->getLoc();
55ae1ea0beSJulian Gross     SmallVector<Value, 4> dynamicOperands;
56ae1ea0beSJulian Gross     for (int i = 0; i < memrefType.getRank(); ++i) {
57ae1ea0beSJulian Gross       if (!memrefType.isDynamicDim(i))
58ae1ea0beSJulian Gross         continue;
59ae1ea0beSJulian Gross       Value size = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
6099260e95SMatthias Springer       Value dim =
6199260e95SMatthias Springer           rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), size);
62ae1ea0beSJulian Gross       dynamicOperands.push_back(dim);
63ae1ea0beSJulian Gross     }
648dca38d5SMatthias Springer 
658dca38d5SMatthias Springer     // Allocate a memref with identity layout.
668dca38d5SMatthias Springer     Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
67ae1ea0beSJulian Gross                                                    dynamicOperands);
688dca38d5SMatthias Springer     // Cast the allocation to the specified type if needed.
698dca38d5SMatthias Springer     if (memrefType != allocType)
708dca38d5SMatthias Springer       alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
718dca38d5SMatthias Springer     rewriter.replaceOp(op, alloc);
7299260e95SMatthias Springer     rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
73ae1ea0beSJulian Gross     return success();
74ae1ea0beSJulian Gross   }
75ae1ea0beSJulian Gross };
76ae1ea0beSJulian Gross } // namespace
77ae1ea0beSJulian Gross 
populateBufferizationToMemRefConversionPatterns(RewritePatternSet & patterns)78ae1ea0beSJulian Gross void mlir::populateBufferizationToMemRefConversionPatterns(
79ae1ea0beSJulian Gross     RewritePatternSet &patterns) {
80ae1ea0beSJulian Gross   patterns.add<CloneOpConversion>(patterns.getContext());
81ae1ea0beSJulian Gross }
82ae1ea0beSJulian Gross 
83ae1ea0beSJulian Gross namespace {
84ae1ea0beSJulian Gross struct BufferizationToMemRefPass
85ae1ea0beSJulian Gross     : public ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
86ae1ea0beSJulian Gross   BufferizationToMemRefPass() = default;
87ae1ea0beSJulian Gross 
runOnOperation__anonf46417ab0211::BufferizationToMemRefPass88ae1ea0beSJulian Gross   void runOnOperation() override {
89ae1ea0beSJulian Gross     RewritePatternSet patterns(&getContext());
90ae1ea0beSJulian Gross     populateBufferizationToMemRefConversionPatterns(patterns);
91ae1ea0beSJulian Gross 
92ae1ea0beSJulian Gross     ConversionTarget target(getContext());
93ae1ea0beSJulian Gross     target.addLegalDialect<memref::MemRefDialect>();
94ae1ea0beSJulian Gross     target.addLegalOp<arith::ConstantOp>();
95ae1ea0beSJulian Gross     target.addIllegalDialect<bufferization::BufferizationDialect>();
96ae1ea0beSJulian Gross 
97ae1ea0beSJulian Gross     if (failed(applyPartialConversion(getOperation(), target,
98ae1ea0beSJulian Gross                                       std::move(patterns))))
99ae1ea0beSJulian Gross       signalPassFailure();
100ae1ea0beSJulian Gross   }
101ae1ea0beSJulian Gross };
102ae1ea0beSJulian Gross } // namespace
103ae1ea0beSJulian Gross 
createBufferizationToMemRefPass()104ae1ea0beSJulian Gross std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
105ae1ea0beSJulian Gross   return std::make_unique<BufferizationToMemRefPass>();
106ae1ea0beSJulian Gross }
107