1f89bb3c0SAlexander Belyaev //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2f89bb3c0SAlexander Belyaev //
3f89bb3c0SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f89bb3c0SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5f89bb3c0SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f89bb3c0SAlexander Belyaev //
7f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===//
8f89bb3c0SAlexander Belyaev
9f89bb3c0SAlexander Belyaev #include "PassDetail.h"
10f89bb3c0SAlexander Belyaev
117a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14d2dacde5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
16f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
19eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
20f89bb3c0SAlexander Belyaev #include "mlir/IR/Operation.h"
21d2dacde5SMatthias Springer #include "mlir/Pass/PassManager.h"
227a1579acSMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23d2dacde5SMatthias Springer #include "mlir/Transforms/Passes.h"
24f89bb3c0SAlexander Belyaev
25f89bb3c0SAlexander Belyaev using namespace mlir;
26f89bb3c0SAlexander Belyaev using namespace mlir::bufferization;
27f89bb3c0SAlexander Belyaev
28f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===//
29f89bb3c0SAlexander Belyaev // BufferizeTypeConverter
30f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===//
31f89bb3c0SAlexander Belyaev
materializeToTensor(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)32f89bb3c0SAlexander Belyaev static Value materializeToTensor(OpBuilder &builder, TensorType type,
33f89bb3c0SAlexander Belyaev ValueRange inputs, Location loc) {
34f89bb3c0SAlexander Belyaev assert(inputs.size() == 1);
35f89bb3c0SAlexander Belyaev assert(inputs[0].getType().isa<BaseMemRefType>());
36f89bb3c0SAlexander Belyaev return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
37f89bb3c0SAlexander Belyaev }
38f89bb3c0SAlexander Belyaev
39f89bb3c0SAlexander Belyaev /// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter()40f89bb3c0SAlexander Belyaev BufferizeTypeConverter::BufferizeTypeConverter() {
41f89bb3c0SAlexander Belyaev // Keep all types unchanged.
42f89bb3c0SAlexander Belyaev addConversion([](Type type) { return type; });
43f89bb3c0SAlexander Belyaev // Convert RankedTensorType to MemRefType.
44f89bb3c0SAlexander Belyaev addConversion([](RankedTensorType type) -> Type {
45f89bb3c0SAlexander Belyaev return MemRefType::get(type.getShape(), type.getElementType());
46f89bb3c0SAlexander Belyaev });
47f89bb3c0SAlexander Belyaev // Convert UnrankedTensorType to UnrankedMemRefType.
48f89bb3c0SAlexander Belyaev addConversion([](UnrankedTensorType type) -> Type {
49f89bb3c0SAlexander Belyaev return UnrankedMemRefType::get(type.getElementType(), 0);
50f89bb3c0SAlexander Belyaev });
51f89bb3c0SAlexander Belyaev addArgumentMaterialization(materializeToTensor);
52f89bb3c0SAlexander Belyaev addSourceMaterialization(materializeToTensor);
53f89bb3c0SAlexander Belyaev addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
54f89bb3c0SAlexander Belyaev ValueRange inputs, Location loc) -> Value {
55fa7c8cb4SMatthias Springer assert(inputs.size() == 1 && "expected exactly one input");
56fa7c8cb4SMatthias Springer
57fa7c8cb4SMatthias Springer if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
58fa7c8cb4SMatthias Springer // MemRef to MemRef cast.
59fa7c8cb4SMatthias Springer assert(inputType != type && "expected different types");
60fa7c8cb4SMatthias Springer // Unranked to ranked and ranked to unranked casts must be explicit.
61fa7c8cb4SMatthias Springer auto rankedDestType = type.dyn_cast<MemRefType>();
62fa7c8cb4SMatthias Springer if (!rankedDestType)
63fa7c8cb4SMatthias Springer return nullptr;
64fa7c8cb4SMatthias Springer FailureOr<Value> replacement =
65fa7c8cb4SMatthias Springer castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
66fa7c8cb4SMatthias Springer if (failed(replacement))
67fa7c8cb4SMatthias Springer return nullptr;
68fa7c8cb4SMatthias Springer return *replacement;
69fa7c8cb4SMatthias Springer }
70fa7c8cb4SMatthias Springer
71fa7c8cb4SMatthias Springer if (inputs[0].getType().isa<TensorType>()) {
72fa7c8cb4SMatthias Springer // Tensor to MemRef cast.
73f89bb3c0SAlexander Belyaev return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
74fa7c8cb4SMatthias Springer }
75fa7c8cb4SMatthias Springer
76fa7c8cb4SMatthias Springer llvm_unreachable("only tensor/memref input types supported");
77f89bb3c0SAlexander Belyaev });
78f89bb3c0SAlexander Belyaev }
79f89bb3c0SAlexander Belyaev
populateBufferizeMaterializationLegality(ConversionTarget & target)80f89bb3c0SAlexander Belyaev void mlir::bufferization::populateBufferizeMaterializationLegality(
81f89bb3c0SAlexander Belyaev ConversionTarget &target) {
82f89bb3c0SAlexander Belyaev target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
83f89bb3c0SAlexander Belyaev }
84f89bb3c0SAlexander Belyaev
85f89bb3c0SAlexander Belyaev namespace {
86f89bb3c0SAlexander Belyaev // In a finalizing bufferize conversion, we know that all tensors have been
87f89bb3c0SAlexander Belyaev // converted to memrefs, thus, this op becomes an identity.
88f89bb3c0SAlexander Belyaev class BufferizeToTensorOp
89f89bb3c0SAlexander Belyaev : public OpConversionPattern<bufferization::ToTensorOp> {
90f89bb3c0SAlexander Belyaev public:
91f89bb3c0SAlexander Belyaev using OpConversionPattern::OpConversionPattern;
92f89bb3c0SAlexander Belyaev LogicalResult
matchAndRewrite(bufferization::ToTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const93f89bb3c0SAlexander Belyaev matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
94f89bb3c0SAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
9599260e95SMatthias Springer rewriter.replaceOp(op, adaptor.getMemref());
96f89bb3c0SAlexander Belyaev return success();
97f89bb3c0SAlexander Belyaev }
98f89bb3c0SAlexander Belyaev };
99f89bb3c0SAlexander Belyaev } // namespace
100f89bb3c0SAlexander Belyaev
101f89bb3c0SAlexander Belyaev namespace {
102f89bb3c0SAlexander Belyaev // In a finalizing bufferize conversion, we know that all tensors have been
103f89bb3c0SAlexander Belyaev // converted to memrefs, thus, this op becomes an identity.
104f89bb3c0SAlexander Belyaev class BufferizeToMemrefOp
105f89bb3c0SAlexander Belyaev : public OpConversionPattern<bufferization::ToMemrefOp> {
106f89bb3c0SAlexander Belyaev public:
107f89bb3c0SAlexander Belyaev using OpConversionPattern::OpConversionPattern;
108f89bb3c0SAlexander Belyaev LogicalResult
matchAndRewrite(bufferization::ToMemrefOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const109f89bb3c0SAlexander Belyaev matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
110f89bb3c0SAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
11199260e95SMatthias Springer rewriter.replaceOp(op, adaptor.getTensor());
112f89bb3c0SAlexander Belyaev return success();
113f89bb3c0SAlexander Belyaev }
114f89bb3c0SAlexander Belyaev };
115f89bb3c0SAlexander Belyaev } // namespace
116f89bb3c0SAlexander Belyaev
populateEliminateBufferizeMaterializationsPatterns(BufferizeTypeConverter & typeConverter,RewritePatternSet & patterns)117f89bb3c0SAlexander Belyaev void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
118f89bb3c0SAlexander Belyaev BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
119f89bb3c0SAlexander Belyaev patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
120f89bb3c0SAlexander Belyaev patterns.getContext());
121f89bb3c0SAlexander Belyaev }
122f89bb3c0SAlexander Belyaev
123f89bb3c0SAlexander Belyaev namespace {
124f89bb3c0SAlexander Belyaev struct FinalizingBufferizePass
125f89bb3c0SAlexander Belyaev : public FinalizingBufferizeBase<FinalizingBufferizePass> {
126f89bb3c0SAlexander Belyaev using FinalizingBufferizeBase<
127f89bb3c0SAlexander Belyaev FinalizingBufferizePass>::FinalizingBufferizeBase;
128f89bb3c0SAlexander Belyaev
runOnOperation__anon82ae79a50711::FinalizingBufferizePass12941574554SRiver Riddle void runOnOperation() override {
13041574554SRiver Riddle auto func = getOperation();
131f89bb3c0SAlexander Belyaev auto *context = &getContext();
132f89bb3c0SAlexander Belyaev
133f89bb3c0SAlexander Belyaev BufferizeTypeConverter typeConverter;
134f89bb3c0SAlexander Belyaev RewritePatternSet patterns(context);
135f89bb3c0SAlexander Belyaev ConversionTarget target(*context);
136f89bb3c0SAlexander Belyaev
137f89bb3c0SAlexander Belyaev populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
138f89bb3c0SAlexander Belyaev
139f89bb3c0SAlexander Belyaev // If all result types are legal, and all block arguments are legal (ensured
140f89bb3c0SAlexander Belyaev // by func conversion above), then all types in the program are legal.
141f89bb3c0SAlexander Belyaev //
142f89bb3c0SAlexander Belyaev // We also check that the operand types are legal to avoid creating invalid
143f89bb3c0SAlexander Belyaev // IR. For example, this prevents
144f89bb3c0SAlexander Belyaev // populateEliminateBufferizeMaterializationsPatterns from updating the
145f89bb3c0SAlexander Belyaev // types of the operands to a return op without updating the enclosing
146f89bb3c0SAlexander Belyaev // function.
147f89bb3c0SAlexander Belyaev target.markUnknownOpDynamicallyLegal(
148f89bb3c0SAlexander Belyaev [&](Operation *op) { return typeConverter.isLegal(op); });
149f89bb3c0SAlexander Belyaev
150f89bb3c0SAlexander Belyaev if (failed(applyFullConversion(func, target, std::move(patterns))))
151f89bb3c0SAlexander Belyaev signalPassFailure();
152f89bb3c0SAlexander Belyaev }
153f89bb3c0SAlexander Belyaev };
154d2dacde5SMatthias Springer
155f287da8aSMatthias Springer static BufferizationOptions::LayoutMapOption
parseLayoutMapOption(const std::string & s)156da3b8200SMehdi Amini parseLayoutMapOption(const std::string &s) {
157f287da8aSMatthias Springer if (s == "fully-dynamic-layout-map")
158f287da8aSMatthias Springer return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
159f287da8aSMatthias Springer if (s == "identity-layout-map")
160f287da8aSMatthias Springer return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
161f287da8aSMatthias Springer if (s == "infer-layout-map")
162f287da8aSMatthias Springer return BufferizationOptions::LayoutMapOption::InferLayoutMap;
163f287da8aSMatthias Springer llvm_unreachable("invalid layout map option");
164f287da8aSMatthias Springer }
165f287da8aSMatthias Springer
166d2dacde5SMatthias Springer struct OneShotBufferizePass
167d2dacde5SMatthias Springer : public OneShotBufferizeBase<OneShotBufferizePass> {
OneShotBufferizePass__anon82ae79a50711::OneShotBufferizePass1685c4f7494SMatthias Springer OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}
169d2dacde5SMatthias Springer
OneShotBufferizePass__anon82ae79a50711::OneShotBufferizePass1709597b16aSMatthias Springer explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
171d2dacde5SMatthias Springer : options(options) {}
172d2dacde5SMatthias Springer
getDependentDialects__anon82ae79a50711::OneShotBufferizePass173d2dacde5SMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override {
174c076fa1cSMatthias Springer registry
175c076fa1cSMatthias Springer .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
176c076fa1cSMatthias Springer registerAllocationOpInterfaceExternalModels(registry);
177d2dacde5SMatthias Springer }
178d2dacde5SMatthias Springer
runOnOperation__anon82ae79a50711::OneShotBufferizePass179d2dacde5SMatthias Springer void runOnOperation() override {
1809597b16aSMatthias Springer OneShotBufferizationOptions opt;
181d2dacde5SMatthias Springer if (!options) {
182d2dacde5SMatthias Springer // Make new bufferization options if none were provided when creating the
183d2dacde5SMatthias Springer // pass.
184855a11eeSMatthias Springer opt.allowReturnAllocs = allowReturnAllocs;
185d2dacde5SMatthias Springer opt.allowUnknownOps = allowUnknownOps;
186d2dacde5SMatthias Springer opt.analysisFuzzerSeed = analysisFuzzerSeed;
187d2dacde5SMatthias Springer opt.createDeallocs = createDeallocs;
188f287da8aSMatthias Springer opt.functionBoundaryTypeConversion =
189f287da8aSMatthias Springer parseLayoutMapOption(functionBoundaryTypeConversion);
190c06f01ffSMatthias Springer if (mustInferMemorySpace)
191c06f01ffSMatthias Springer opt.defaultMemorySpace = None;
192d2dacde5SMatthias Springer opt.printConflicts = printConflicts;
193d2dacde5SMatthias Springer opt.testAnalysisOnly = testAnalysisOnly;
194d6dab38aSMatthias Springer opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
195d2dacde5SMatthias Springer
196606f7c8fSMatthias Springer // Configure type converter.
197606f7c8fSMatthias Springer BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
198606f7c8fSMatthias Springer parseLayoutMapOption(unknownTypeConversion);
199606f7c8fSMatthias Springer opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
200606f7c8fSMatthias Springer const BufferizationOptions &options) {
201606f7c8fSMatthias Springer auto tensorType = value.getType().cast<TensorType>();
202606f7c8fSMatthias Springer if (unknownTypeConversionOption ==
203606f7c8fSMatthias Springer BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
204606f7c8fSMatthias Springer return bufferization::getMemRefTypeWithStaticIdentityLayout(
205606f7c8fSMatthias Springer tensorType, memorySpace);
206606f7c8fSMatthias Springer assert(
207606f7c8fSMatthias Springer unknownTypeConversionOption ==
208606f7c8fSMatthias Springer BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
209606f7c8fSMatthias Springer "invalid layout map option");
210606f7c8fSMatthias Springer return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
211606f7c8fSMatthias Springer memorySpace);
212606f7c8fSMatthias Springer };
213606f7c8fSMatthias Springer
214606f7c8fSMatthias Springer // Configure op filter.
215*b7f93c28SJeff Niu OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
216d2dacde5SMatthias Springer // Filter may be specified via options.
217d2dacde5SMatthias Springer if (this->dialectFilter.hasValue())
218ad44495aSjacquesguan return llvm::is_contained(this->dialectFilter,
219ad44495aSjacquesguan op->getDialect()->getNamespace());
220d2dacde5SMatthias Springer // No filter specified: All other ops are allowed.
221d2dacde5SMatthias Springer return true;
222d2dacde5SMatthias Springer };
2231534177fSMatthias Springer opt.opFilter.allowOperation(filterFn);
224d2dacde5SMatthias Springer } else {
225d2dacde5SMatthias Springer opt = *options;
226d2dacde5SMatthias Springer }
227d2dacde5SMatthias Springer
228d2dacde5SMatthias Springer ModuleOp moduleOp = getOperation();
229d6dab38aSMatthias Springer if (opt.bufferizeFunctionBoundaries) {
230e07a7fd5SMatthias Springer if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
231e07a7fd5SMatthias Springer signalPassFailure();
232e07a7fd5SMatthias Springer return;
233e07a7fd5SMatthias Springer }
234e07a7fd5SMatthias Springer } else {
235d2dacde5SMatthias Springer if (failed(runOneShotBufferize(moduleOp, opt))) {
236d2dacde5SMatthias Springer signalPassFailure();
237d2dacde5SMatthias Springer return;
238d2dacde5SMatthias Springer }
239e07a7fd5SMatthias Springer }
240d2dacde5SMatthias Springer
241d2dacde5SMatthias Springer if (opt.testAnalysisOnly)
242d2dacde5SMatthias Springer return;
243d2dacde5SMatthias Springer
244d2dacde5SMatthias Springer OpPassManager cleanupPipeline("builtin.module");
245d2dacde5SMatthias Springer cleanupPipeline.addPass(createCanonicalizerPass());
246d2dacde5SMatthias Springer cleanupPipeline.addPass(createCSEPass());
247d2dacde5SMatthias Springer cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
248d2dacde5SMatthias Springer (void)runPipeline(cleanupPipeline, moduleOp);
249d2dacde5SMatthias Springer }
250d2dacde5SMatthias Springer
251d2dacde5SMatthias Springer private:
2529597b16aSMatthias Springer llvm::Optional<OneShotBufferizationOptions> options;
253d2dacde5SMatthias Springer };
254f89bb3c0SAlexander Belyaev } // namespace
255f89bb3c0SAlexander Belyaev
256ffdbecccSMatthias Springer namespace {
257ffdbecccSMatthias Springer struct BufferizationBufferizePass
258ffdbecccSMatthias Springer : public BufferizationBufferizeBase<BufferizationBufferizePass> {
runOnOperation__anon82ae79a50b11::BufferizationBufferizePass259ffdbecccSMatthias Springer void runOnOperation() override {
260ffdbecccSMatthias Springer BufferizationOptions options = getPartialBufferizationOptions();
2611534177fSMatthias Springer options.opFilter.allowDialect<BufferizationDialect>();
262ffdbecccSMatthias Springer
263ffdbecccSMatthias Springer if (failed(bufferizeOp(getOperation(), options)))
264ffdbecccSMatthias Springer signalPassFailure();
265ffdbecccSMatthias Springer }
266ffdbecccSMatthias Springer
getDependentDialects__anon82ae79a50b11::BufferizationBufferizePass267ffdbecccSMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override {
268ffdbecccSMatthias Springer registry
269ffdbecccSMatthias Springer .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
270ffdbecccSMatthias Springer }
271ffdbecccSMatthias Springer };
272ffdbecccSMatthias Springer } // namespace
273ffdbecccSMatthias Springer
createBufferizationBufferizePass()274ffdbecccSMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
275ffdbecccSMatthias Springer return std::make_unique<BufferizationBufferizePass>();
276ffdbecccSMatthias Springer }
277ffdbecccSMatthias Springer
createOneShotBufferizePass()278d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
279d2dacde5SMatthias Springer return std::make_unique<OneShotBufferizePass>();
280d2dacde5SMatthias Springer }
281d2dacde5SMatthias Springer
createOneShotBufferizePass(const OneShotBufferizationOptions & options)282d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
2839597b16aSMatthias Springer const OneShotBufferizationOptions &options) {
284d2dacde5SMatthias Springer return std::make_unique<OneShotBufferizePass>(options);
285d2dacde5SMatthias Springer }
286d2dacde5SMatthias Springer
28758ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createFinalizingBufferizePass()288f89bb3c0SAlexander Belyaev mlir::bufferization::createFinalizingBufferizePass() {
289f89bb3c0SAlexander Belyaev return std::make_unique<FinalizingBufferizePass>();
290f89bb3c0SAlexander Belyaev }
2917a1579acSMatthias Springer
29249e37000SMatthias Springer //===----------------------------------------------------------------------===//
29349e37000SMatthias Springer // BufferizableOpInterface-based Bufferization
29449e37000SMatthias Springer //===----------------------------------------------------------------------===//
29549e37000SMatthias Springer
isaTensor(Type t)2967a1579acSMatthias Springer static bool isaTensor(Type t) { return t.isa<TensorType>(); }
2977a1579acSMatthias Springer
2987a1579acSMatthias Springer /// Return true if the given op has a tensor result or a tensor operand.
hasTensorSemantics(Operation * op)2997a1579acSMatthias Springer static bool hasTensorSemantics(Operation *op) {
30070777d96SMatthias Springer if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
30170777d96SMatthias Springer bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
30270777d96SMatthias Springer bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
30370777d96SMatthias Springer return hasTensorArg || hasTensorResult;
30470777d96SMatthias Springer }
30570777d96SMatthias Springer
3067a1579acSMatthias Springer bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
3077a1579acSMatthias Springer bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
3087a1579acSMatthias Springer return hasTensorResult || hasTensorOperand;
3097a1579acSMatthias Springer }
3107a1579acSMatthias Springer
311d820acddSMatthias Springer namespace {
312d820acddSMatthias Springer /// A rewriter that keeps track of extra information during bufferization.
313d820acddSMatthias Springer class BufferizationRewriter : public IRRewriter {
314d820acddSMatthias Springer public:
BufferizationRewriter(MLIRContext * ctx,DenseSet<Operation * > & erasedOps,DenseSet<Operation * > & toMemrefOps,SmallVector<Operation * > & worklist,const BufferizationOptions & options,const OpFilter * opFilter)315d820acddSMatthias Springer BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
316d820acddSMatthias Springer DenseSet<Operation *> &toMemrefOps,
317b3ebe3beSMatthias Springer SmallVector<Operation *> &worklist,
3182f0a634cSMatthias Springer const BufferizationOptions &options,
3192f0a634cSMatthias Springer const OpFilter *opFilter)
320d820acddSMatthias Springer : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
321b3ebe3beSMatthias Springer worklist(worklist), analysisState(options), opFilter(opFilter) {}
322d820acddSMatthias Springer
323d820acddSMatthias Springer protected:
notifyOperationRemoved(Operation * op)324d820acddSMatthias Springer void notifyOperationRemoved(Operation *op) override {
325d820acddSMatthias Springer IRRewriter::notifyOperationRemoved(op);
326d820acddSMatthias Springer erasedOps.insert(op);
3279785eb1bSMatthias Springer // Erase if present.
3289785eb1bSMatthias Springer toMemrefOps.erase(op);
329d820acddSMatthias Springer }
330d820acddSMatthias Springer
notifyOperationInserted(Operation * op)331d820acddSMatthias Springer void notifyOperationInserted(Operation *op) override {
332d820acddSMatthias Springer IRRewriter::notifyOperationInserted(op);
333b3ebe3beSMatthias Springer erasedOps.erase(op);
334d820acddSMatthias Springer
335d820acddSMatthias Springer // Keep track of to_memref ops.
336d820acddSMatthias Springer if (isa<ToMemrefOp>(op)) {
337d820acddSMatthias Springer toMemrefOps.insert(op);
338d820acddSMatthias Springer return;
339d820acddSMatthias Springer }
340d820acddSMatthias Springer
341d820acddSMatthias Springer // Skip to_tensor ops.
342d820acddSMatthias Springer if (isa<ToTensorOp>(op))
343d820acddSMatthias Springer return;
344d820acddSMatthias Springer
3452f0a634cSMatthias Springer // Skip non-tensor ops.
3462f0a634cSMatthias Springer if (!hasTensorSemantics(op))
3472f0a634cSMatthias Springer return;
3482f0a634cSMatthias Springer
349b3ebe3beSMatthias Springer // Skip ops that are not allowed to be bufferized.
350b3ebe3beSMatthias Springer auto const &options = analysisState.getOptions();
3512f0a634cSMatthias Springer if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
3522f0a634cSMatthias Springer return;
3532f0a634cSMatthias Springer
354b3ebe3beSMatthias Springer #ifndef NDEBUG
355b3ebe3beSMatthias Springer // Read-only tensor ops may be created during bufferization. Ops that are
356b3ebe3beSMatthias Springer // writing should not be created because such ops were never analyzed.
357b3ebe3beSMatthias Springer // Bufferizing such ops could introduce a RaW conflict.
358b3ebe3beSMatthias Springer for (OpOperand &operand : op->getOpOperands())
359b3ebe3beSMatthias Springer if (operand.get().getType().isa<TensorType>())
360b3ebe3beSMatthias Springer assert(!analysisState.bufferizesToMemoryWrite(operand) &&
361b3ebe3beSMatthias Springer "creating tensor ops that bufferize to a memory write is not "
362b3ebe3beSMatthias Springer "allowed during bufferization");
363b3ebe3beSMatthias Springer #endif // NDEBUG
364b3ebe3beSMatthias Springer
365b3ebe3beSMatthias Springer // Add op to worklist.
366b3ebe3beSMatthias Springer worklist.push_back(op);
367d820acddSMatthias Springer }
368d820acddSMatthias Springer
369d820acddSMatthias Springer private:
370d820acddSMatthias Springer /// A set of all erased ops.
371d820acddSMatthias Springer DenseSet<Operation *> &erasedOps;
372d820acddSMatthias Springer
373d820acddSMatthias Springer /// A set of all to_memref ops.
374d820acddSMatthias Springer DenseSet<Operation *> &toMemrefOps;
375d820acddSMatthias Springer
376b3ebe3beSMatthias Springer /// The worklist of ops to be bufferized.
377b3ebe3beSMatthias Springer SmallVector<Operation *> &worklist;
3782f0a634cSMatthias Springer
379b3ebe3beSMatthias Springer /// The analysis state. Used for debug assertions and access to the
380b3ebe3beSMatthias Springer /// bufferization options.
381b3ebe3beSMatthias Springer const AnalysisState analysisState;
382b3ebe3beSMatthias Springer
383b3ebe3beSMatthias Springer /// An extra op filter for bufferization.
3842f0a634cSMatthias Springer const OpFilter *opFilter;
385d820acddSMatthias Springer };
386d820acddSMatthias Springer } // namespace
387d820acddSMatthias Springer
bufferizeOp(Operation * op,const BufferizationOptions & options,bool copyBeforeWrite,const OpFilter * opFilter)3882f0a634cSMatthias Springer LogicalResult bufferization::bufferizeOp(Operation *op,
389b3ebe3beSMatthias Springer const BufferizationOptions &options,
390b3ebe3beSMatthias Springer bool copyBeforeWrite,
3912f0a634cSMatthias Springer const OpFilter *opFilter) {
392b3ebe3beSMatthias Springer if (copyBeforeWrite) {
393b3ebe3beSMatthias Springer AnalysisState state(options);
394b3ebe3beSMatthias Springer if (failed(insertTensorCopies(op, state)))
395b3ebe3beSMatthias Springer return failure();
396b3ebe3beSMatthias Springer }
397b3ebe3beSMatthias Springer
398d820acddSMatthias Springer // Keep track of to_memref ops.
399d820acddSMatthias Springer DenseSet<Operation *> toMemrefOps;
400d820acddSMatthias Springer op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
401d820acddSMatthias Springer
402d820acddSMatthias Springer // Gather all bufferizable ops in top-to-bottom order.
40376b16010SMatthias Springer //
404d820acddSMatthias Springer // We should ideally know the exact memref type of all operands when
405d820acddSMatthias Springer // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
406f287da8aSMatthias Springer // Otherwise, we have to use a memref type with a fully dynamic layout map to
407f287da8aSMatthias Springer // avoid copies. We are currently missing patterns for layout maps to
408f287da8aSMatthias Springer // canonicalize away (or canonicalize to more precise layouts).
409ba9d886dSMatthias Springer //
410ba9d886dSMatthias Springer // FuncOps must be bufferized before their bodies, so add them to the worklist
411ba9d886dSMatthias Springer // first.
412d820acddSMatthias Springer SmallVector<Operation *> worklist;
413ba9d886dSMatthias Springer op->walk([&](func::FuncOp funcOp) {
414ba9d886dSMatthias Springer if (hasTensorSemantics(funcOp))
415ba9d886dSMatthias Springer worklist.push_back(funcOp);
416ba9d886dSMatthias Springer });
417ba9d886dSMatthias Springer op->walk<WalkOrder::PostOrder>([&](Operation *op) {
418ba9d886dSMatthias Springer if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
419d820acddSMatthias Springer worklist.push_back(op);
420d820acddSMatthias Springer });
4216fc753adSMatthias Springer
422d820acddSMatthias Springer // Keep track of all erased ops.
423d820acddSMatthias Springer DenseSet<Operation *> erasedOps;
4247a1579acSMatthias Springer
425d820acddSMatthias Springer // Bufferize all ops.
426d820acddSMatthias Springer BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
427b3ebe3beSMatthias Springer worklist, options, opFilter);
428d820acddSMatthias Springer for (unsigned i = 0; i < worklist.size(); ++i) {
429d820acddSMatthias Springer Operation *op = worklist[i];
430d820acddSMatthias Springer // Skip ops that were erased.
431d820acddSMatthias Springer if (erasedOps.contains(op))
432d820acddSMatthias Springer continue;
4339785eb1bSMatthias Springer // Skip ops that are not bufferizable or not allowed.
4349785eb1bSMatthias Springer auto bufferizableOp = options.dynCastBufferizableOp(op);
435d820acddSMatthias Springer if (!bufferizableOp)
436d820acddSMatthias Springer continue;
4372f0a634cSMatthias Springer if (opFilter && !opFilter->isOpAllowed(op))
4382f0a634cSMatthias Springer continue;
4399785eb1bSMatthias Springer // Skip ops that no longer have tensor semantics.
4409785eb1bSMatthias Springer if (!hasTensorSemantics(op))
441d820acddSMatthias Springer continue;
442d820acddSMatthias Springer // Bufferize the op.
443d820acddSMatthias Springer rewriter.setInsertionPoint(op);
444b55d55ecSMatthias Springer if (failed(bufferizableOp.bufferize(rewriter, options)))
4450b293bf0SMatthias Springer return op->emitError("failed to bufferize op");
446d820acddSMatthias Springer }
447d820acddSMatthias Springer
448d820acddSMatthias Springer // Fold all to_memref(to_tensor(x)) pairs.
449d820acddSMatthias Springer for (Operation *op : toMemrefOps) {
450d820acddSMatthias Springer rewriter.setInsertionPoint(op);
451d820acddSMatthias Springer (void)bufferization::foldToMemrefToTensorPair(rewriter,
452d820acddSMatthias Springer cast<ToMemrefOp>(op));
453d820acddSMatthias Springer }
454d820acddSMatthias Springer
455d820acddSMatthias Springer /// Check the result of bufferization. Return an error if an op was not
456d820acddSMatthias Springer /// bufferized, unless partial bufferization is allowed.
457b55d55ecSMatthias Springer if (options.allowUnknownOps)
458d820acddSMatthias Springer return success();
459d820acddSMatthias Springer
460d820acddSMatthias Springer for (Operation *op : worklist) {
461d820acddSMatthias Springer // Skip ops that are entirely gone.
462d820acddSMatthias Springer if (erasedOps.contains(op))
463d820acddSMatthias Springer continue;
464d820acddSMatthias Springer // Ops that no longer have tensor semantics (because they were updated
465d820acddSMatthias Springer // in-place) are allowed.
466d820acddSMatthias Springer if (!hasTensorSemantics(op))
467d820acddSMatthias Springer continue;
468d820acddSMatthias Springer // Continue ops that are not allowed.
469d820acddSMatthias Springer if (!options.isOpAllowed(op))
470d820acddSMatthias Springer continue;
4712f0a634cSMatthias Springer if (opFilter && !opFilter->isOpAllowed(op))
4722f0a634cSMatthias Springer continue;
473d820acddSMatthias Springer // Ops without any uses and no side effects will fold away.
474d820acddSMatthias Springer if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
475d820acddSMatthias Springer continue;
476b3ebe3beSMatthias Springer // ToTensorOps/ToMemrefOps are allowed in the output.
477b3ebe3beSMatthias Springer if (isa<ToTensorOp, ToMemrefOp>(op))
478b3ebe3beSMatthias Springer continue;
479d820acddSMatthias Springer return op->emitError("op was not bufferized");
480d820acddSMatthias Springer }
48105e0495fSMatthias Springer
48205e0495fSMatthias Springer return success();
4837a1579acSMatthias Springer }
484daf18108SMatthias Springer
getPartialBufferizationOptions()485cdb7675cSMatthias Springer BufferizationOptions bufferization::getPartialBufferizationOptions() {
486cdb7675cSMatthias Springer BufferizationOptions options;
487cdb7675cSMatthias Springer options.allowUnknownOps = true;
488cdb7675cSMatthias Springer options.createDeallocs = false;
489b3ebe3beSMatthias Springer options.enforceAliasingInvariants = false;
490606f7c8fSMatthias Springer options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
491606f7c8fSMatthias Springer const BufferizationOptions &options) {
492606f7c8fSMatthias Springer return getMemRefTypeWithStaticIdentityLayout(
493606f7c8fSMatthias Springer value.getType().cast<TensorType>(), memorySpace);
494606f7c8fSMatthias Springer };
495b3ebe3beSMatthias Springer options.opFilter.allowDialect<BufferizationDialect>();
496daf18108SMatthias Springer return options;
497daf18108SMatthias Springer }
498