1f89bb3c0SAlexander Belyaev //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2f89bb3c0SAlexander Belyaev //
3f89bb3c0SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f89bb3c0SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5f89bb3c0SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f89bb3c0SAlexander Belyaev //
7f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===//
8f89bb3c0SAlexander Belyaev 
9f89bb3c0SAlexander Belyaev #include "PassDetail.h"
10f89bb3c0SAlexander Belyaev 
117a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14d2dacde5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
16f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
19eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
20f89bb3c0SAlexander Belyaev #include "mlir/IR/Operation.h"
21d2dacde5SMatthias Springer #include "mlir/Pass/PassManager.h"
227a1579acSMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23d2dacde5SMatthias Springer #include "mlir/Transforms/Passes.h"
24f89bb3c0SAlexander Belyaev 
25f89bb3c0SAlexander Belyaev using namespace mlir;
26f89bb3c0SAlexander Belyaev using namespace mlir::bufferization;
27f89bb3c0SAlexander Belyaev 
28f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===//
29f89bb3c0SAlexander Belyaev // BufferizeTypeConverter
30f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===//
31f89bb3c0SAlexander Belyaev 
materializeToTensor(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)32f89bb3c0SAlexander Belyaev static Value materializeToTensor(OpBuilder &builder, TensorType type,
33f89bb3c0SAlexander Belyaev                                  ValueRange inputs, Location loc) {
34f89bb3c0SAlexander Belyaev   assert(inputs.size() == 1);
35f89bb3c0SAlexander Belyaev   assert(inputs[0].getType().isa<BaseMemRefType>());
36f89bb3c0SAlexander Belyaev   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
37f89bb3c0SAlexander Belyaev }
38f89bb3c0SAlexander Belyaev 
39f89bb3c0SAlexander Belyaev /// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter()40f89bb3c0SAlexander Belyaev BufferizeTypeConverter::BufferizeTypeConverter() {
41f89bb3c0SAlexander Belyaev   // Keep all types unchanged.
42f89bb3c0SAlexander Belyaev   addConversion([](Type type) { return type; });
43f89bb3c0SAlexander Belyaev   // Convert RankedTensorType to MemRefType.
44f89bb3c0SAlexander Belyaev   addConversion([](RankedTensorType type) -> Type {
45f89bb3c0SAlexander Belyaev     return MemRefType::get(type.getShape(), type.getElementType());
46f89bb3c0SAlexander Belyaev   });
47f89bb3c0SAlexander Belyaev   // Convert UnrankedTensorType to UnrankedMemRefType.
48f89bb3c0SAlexander Belyaev   addConversion([](UnrankedTensorType type) -> Type {
49f89bb3c0SAlexander Belyaev     return UnrankedMemRefType::get(type.getElementType(), 0);
50f89bb3c0SAlexander Belyaev   });
51f89bb3c0SAlexander Belyaev   addArgumentMaterialization(materializeToTensor);
52f89bb3c0SAlexander Belyaev   addSourceMaterialization(materializeToTensor);
53f89bb3c0SAlexander Belyaev   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
54f89bb3c0SAlexander Belyaev                               ValueRange inputs, Location loc) -> Value {
55fa7c8cb4SMatthias Springer     assert(inputs.size() == 1 && "expected exactly one input");
56fa7c8cb4SMatthias Springer 
57fa7c8cb4SMatthias Springer     if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
58fa7c8cb4SMatthias Springer       // MemRef to MemRef cast.
59fa7c8cb4SMatthias Springer       assert(inputType != type && "expected different types");
60fa7c8cb4SMatthias Springer       // Unranked to ranked and ranked to unranked casts must be explicit.
61fa7c8cb4SMatthias Springer       auto rankedDestType = type.dyn_cast<MemRefType>();
62fa7c8cb4SMatthias Springer       if (!rankedDestType)
63fa7c8cb4SMatthias Springer         return nullptr;
64fa7c8cb4SMatthias Springer       FailureOr<Value> replacement =
65fa7c8cb4SMatthias Springer           castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
66fa7c8cb4SMatthias Springer       if (failed(replacement))
67fa7c8cb4SMatthias Springer         return nullptr;
68fa7c8cb4SMatthias Springer       return *replacement;
69fa7c8cb4SMatthias Springer     }
70fa7c8cb4SMatthias Springer 
71fa7c8cb4SMatthias Springer     if (inputs[0].getType().isa<TensorType>()) {
72fa7c8cb4SMatthias Springer       // Tensor to MemRef cast.
73f89bb3c0SAlexander Belyaev       return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
74fa7c8cb4SMatthias Springer     }
75fa7c8cb4SMatthias Springer 
76fa7c8cb4SMatthias Springer     llvm_unreachable("only tensor/memref input types supported");
77f89bb3c0SAlexander Belyaev   });
78f89bb3c0SAlexander Belyaev }
79f89bb3c0SAlexander Belyaev 
populateBufferizeMaterializationLegality(ConversionTarget & target)80f89bb3c0SAlexander Belyaev void mlir::bufferization::populateBufferizeMaterializationLegality(
81f89bb3c0SAlexander Belyaev     ConversionTarget &target) {
82f89bb3c0SAlexander Belyaev   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
83f89bb3c0SAlexander Belyaev }
84f89bb3c0SAlexander Belyaev 
85f89bb3c0SAlexander Belyaev namespace {
86f89bb3c0SAlexander Belyaev // In a finalizing bufferize conversion, we know that all tensors have been
87f89bb3c0SAlexander Belyaev // converted to memrefs, thus, this op becomes an identity.
88f89bb3c0SAlexander Belyaev class BufferizeToTensorOp
89f89bb3c0SAlexander Belyaev     : public OpConversionPattern<bufferization::ToTensorOp> {
90f89bb3c0SAlexander Belyaev public:
91f89bb3c0SAlexander Belyaev   using OpConversionPattern::OpConversionPattern;
92f89bb3c0SAlexander Belyaev   LogicalResult
matchAndRewrite(bufferization::ToTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const93f89bb3c0SAlexander Belyaev   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
94f89bb3c0SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
9599260e95SMatthias Springer     rewriter.replaceOp(op, adaptor.getMemref());
96f89bb3c0SAlexander Belyaev     return success();
97f89bb3c0SAlexander Belyaev   }
98f89bb3c0SAlexander Belyaev };
99f89bb3c0SAlexander Belyaev } // namespace
100f89bb3c0SAlexander Belyaev 
101f89bb3c0SAlexander Belyaev namespace {
102f89bb3c0SAlexander Belyaev // In a finalizing bufferize conversion, we know that all tensors have been
103f89bb3c0SAlexander Belyaev // converted to memrefs, thus, this op becomes an identity.
104f89bb3c0SAlexander Belyaev class BufferizeToMemrefOp
105f89bb3c0SAlexander Belyaev     : public OpConversionPattern<bufferization::ToMemrefOp> {
106f89bb3c0SAlexander Belyaev public:
107f89bb3c0SAlexander Belyaev   using OpConversionPattern::OpConversionPattern;
108f89bb3c0SAlexander Belyaev   LogicalResult
matchAndRewrite(bufferization::ToMemrefOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const109f89bb3c0SAlexander Belyaev   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
110f89bb3c0SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
11199260e95SMatthias Springer     rewriter.replaceOp(op, adaptor.getTensor());
112f89bb3c0SAlexander Belyaev     return success();
113f89bb3c0SAlexander Belyaev   }
114f89bb3c0SAlexander Belyaev };
115f89bb3c0SAlexander Belyaev } // namespace
116f89bb3c0SAlexander Belyaev 
populateEliminateBufferizeMaterializationsPatterns(BufferizeTypeConverter & typeConverter,RewritePatternSet & patterns)117f89bb3c0SAlexander Belyaev void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
118f89bb3c0SAlexander Belyaev     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
119f89bb3c0SAlexander Belyaev   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
120f89bb3c0SAlexander Belyaev                                                          patterns.getContext());
121f89bb3c0SAlexander Belyaev }
122f89bb3c0SAlexander Belyaev 
123f89bb3c0SAlexander Belyaev namespace {
124f89bb3c0SAlexander Belyaev struct FinalizingBufferizePass
125f89bb3c0SAlexander Belyaev     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
126f89bb3c0SAlexander Belyaev   using FinalizingBufferizeBase<
127f89bb3c0SAlexander Belyaev       FinalizingBufferizePass>::FinalizingBufferizeBase;
128f89bb3c0SAlexander Belyaev 
runOnOperation__anon82ae79a50711::FinalizingBufferizePass12941574554SRiver Riddle   void runOnOperation() override {
13041574554SRiver Riddle     auto func = getOperation();
131f89bb3c0SAlexander Belyaev     auto *context = &getContext();
132f89bb3c0SAlexander Belyaev 
133f89bb3c0SAlexander Belyaev     BufferizeTypeConverter typeConverter;
134f89bb3c0SAlexander Belyaev     RewritePatternSet patterns(context);
135f89bb3c0SAlexander Belyaev     ConversionTarget target(*context);
136f89bb3c0SAlexander Belyaev 
137f89bb3c0SAlexander Belyaev     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
138f89bb3c0SAlexander Belyaev 
139f89bb3c0SAlexander Belyaev     // If all result types are legal, and all block arguments are legal (ensured
140f89bb3c0SAlexander Belyaev     // by func conversion above), then all types in the program are legal.
141f89bb3c0SAlexander Belyaev     //
142f89bb3c0SAlexander Belyaev     // We also check that the operand types are legal to avoid creating invalid
143f89bb3c0SAlexander Belyaev     // IR. For example, this prevents
144f89bb3c0SAlexander Belyaev     // populateEliminateBufferizeMaterializationsPatterns from updating the
145f89bb3c0SAlexander Belyaev     // types of the operands to a return op without updating the enclosing
146f89bb3c0SAlexander Belyaev     // function.
147f89bb3c0SAlexander Belyaev     target.markUnknownOpDynamicallyLegal(
148f89bb3c0SAlexander Belyaev         [&](Operation *op) { return typeConverter.isLegal(op); });
149f89bb3c0SAlexander Belyaev 
150f89bb3c0SAlexander Belyaev     if (failed(applyFullConversion(func, target, std::move(patterns))))
151f89bb3c0SAlexander Belyaev       signalPassFailure();
152f89bb3c0SAlexander Belyaev   }
153f89bb3c0SAlexander Belyaev };
154d2dacde5SMatthias Springer 
155f287da8aSMatthias Springer static BufferizationOptions::LayoutMapOption
parseLayoutMapOption(const std::string & s)156da3b8200SMehdi Amini parseLayoutMapOption(const std::string &s) {
157f287da8aSMatthias Springer   if (s == "fully-dynamic-layout-map")
158f287da8aSMatthias Springer     return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
159f287da8aSMatthias Springer   if (s == "identity-layout-map")
160f287da8aSMatthias Springer     return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
161f287da8aSMatthias Springer   if (s == "infer-layout-map")
162f287da8aSMatthias Springer     return BufferizationOptions::LayoutMapOption::InferLayoutMap;
163f287da8aSMatthias Springer   llvm_unreachable("invalid layout map option");
164f287da8aSMatthias Springer }
165f287da8aSMatthias Springer 
166d2dacde5SMatthias Springer struct OneShotBufferizePass
167d2dacde5SMatthias Springer     : public OneShotBufferizeBase<OneShotBufferizePass> {
OneShotBufferizePass__anon82ae79a50711::OneShotBufferizePass1685c4f7494SMatthias Springer   OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}
169d2dacde5SMatthias Springer 
OneShotBufferizePass__anon82ae79a50711::OneShotBufferizePass1709597b16aSMatthias Springer   explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
171d2dacde5SMatthias Springer       : options(options) {}
172d2dacde5SMatthias Springer 
getDependentDialects__anon82ae79a50711::OneShotBufferizePass173d2dacde5SMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
174c076fa1cSMatthias Springer     registry
175c076fa1cSMatthias Springer         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
176c076fa1cSMatthias Springer     registerAllocationOpInterfaceExternalModels(registry);
177d2dacde5SMatthias Springer   }
178d2dacde5SMatthias Springer 
runOnOperation__anon82ae79a50711::OneShotBufferizePass179d2dacde5SMatthias Springer   void runOnOperation() override {
1809597b16aSMatthias Springer     OneShotBufferizationOptions opt;
181d2dacde5SMatthias Springer     if (!options) {
182d2dacde5SMatthias Springer       // Make new bufferization options if none were provided when creating the
183d2dacde5SMatthias Springer       // pass.
184855a11eeSMatthias Springer       opt.allowReturnAllocs = allowReturnAllocs;
185d2dacde5SMatthias Springer       opt.allowUnknownOps = allowUnknownOps;
186d2dacde5SMatthias Springer       opt.analysisFuzzerSeed = analysisFuzzerSeed;
187d2dacde5SMatthias Springer       opt.createDeallocs = createDeallocs;
188f287da8aSMatthias Springer       opt.functionBoundaryTypeConversion =
189f287da8aSMatthias Springer           parseLayoutMapOption(functionBoundaryTypeConversion);
190c06f01ffSMatthias Springer       if (mustInferMemorySpace)
191c06f01ffSMatthias Springer         opt.defaultMemorySpace = None;
192d2dacde5SMatthias Springer       opt.printConflicts = printConflicts;
193d2dacde5SMatthias Springer       opt.testAnalysisOnly = testAnalysisOnly;
194d6dab38aSMatthias Springer       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
195d2dacde5SMatthias Springer 
196606f7c8fSMatthias Springer       // Configure type converter.
197606f7c8fSMatthias Springer       BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
198606f7c8fSMatthias Springer           parseLayoutMapOption(unknownTypeConversion);
199606f7c8fSMatthias Springer       opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
200606f7c8fSMatthias Springer                                        const BufferizationOptions &options) {
201606f7c8fSMatthias Springer         auto tensorType = value.getType().cast<TensorType>();
202606f7c8fSMatthias Springer         if (unknownTypeConversionOption ==
203606f7c8fSMatthias Springer             BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
204606f7c8fSMatthias Springer           return bufferization::getMemRefTypeWithStaticIdentityLayout(
205606f7c8fSMatthias Springer               tensorType, memorySpace);
206606f7c8fSMatthias Springer         assert(
207606f7c8fSMatthias Springer             unknownTypeConversionOption ==
208606f7c8fSMatthias Springer                 BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
209606f7c8fSMatthias Springer             "invalid layout map option");
210606f7c8fSMatthias Springer         return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
211606f7c8fSMatthias Springer                                                                   memorySpace);
212606f7c8fSMatthias Springer       };
213606f7c8fSMatthias Springer 
214606f7c8fSMatthias Springer       // Configure op filter.
215*b7f93c28SJeff Niu       OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
216d2dacde5SMatthias Springer         // Filter may be specified via options.
217d2dacde5SMatthias Springer         if (this->dialectFilter.hasValue())
218ad44495aSjacquesguan           return llvm::is_contained(this->dialectFilter,
219ad44495aSjacquesguan                                     op->getDialect()->getNamespace());
220d2dacde5SMatthias Springer         // No filter specified: All other ops are allowed.
221d2dacde5SMatthias Springer         return true;
222d2dacde5SMatthias Springer       };
2231534177fSMatthias Springer       opt.opFilter.allowOperation(filterFn);
224d2dacde5SMatthias Springer     } else {
225d2dacde5SMatthias Springer       opt = *options;
226d2dacde5SMatthias Springer     }
227d2dacde5SMatthias Springer 
228d2dacde5SMatthias Springer     ModuleOp moduleOp = getOperation();
229d6dab38aSMatthias Springer     if (opt.bufferizeFunctionBoundaries) {
230e07a7fd5SMatthias Springer       if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
231e07a7fd5SMatthias Springer         signalPassFailure();
232e07a7fd5SMatthias Springer         return;
233e07a7fd5SMatthias Springer       }
234e07a7fd5SMatthias Springer     } else {
235d2dacde5SMatthias Springer       if (failed(runOneShotBufferize(moduleOp, opt))) {
236d2dacde5SMatthias Springer         signalPassFailure();
237d2dacde5SMatthias Springer         return;
238d2dacde5SMatthias Springer       }
239e07a7fd5SMatthias Springer     }
240d2dacde5SMatthias Springer 
241d2dacde5SMatthias Springer     if (opt.testAnalysisOnly)
242d2dacde5SMatthias Springer       return;
243d2dacde5SMatthias Springer 
244d2dacde5SMatthias Springer     OpPassManager cleanupPipeline("builtin.module");
245d2dacde5SMatthias Springer     cleanupPipeline.addPass(createCanonicalizerPass());
246d2dacde5SMatthias Springer     cleanupPipeline.addPass(createCSEPass());
247d2dacde5SMatthias Springer     cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
248d2dacde5SMatthias Springer     (void)runPipeline(cleanupPipeline, moduleOp);
249d2dacde5SMatthias Springer   }
250d2dacde5SMatthias Springer 
251d2dacde5SMatthias Springer private:
2529597b16aSMatthias Springer   llvm::Optional<OneShotBufferizationOptions> options;
253d2dacde5SMatthias Springer };
254f89bb3c0SAlexander Belyaev } // namespace
255f89bb3c0SAlexander Belyaev 
256ffdbecccSMatthias Springer namespace {
257ffdbecccSMatthias Springer struct BufferizationBufferizePass
258ffdbecccSMatthias Springer     : public BufferizationBufferizeBase<BufferizationBufferizePass> {
runOnOperation__anon82ae79a50b11::BufferizationBufferizePass259ffdbecccSMatthias Springer   void runOnOperation() override {
260ffdbecccSMatthias Springer     BufferizationOptions options = getPartialBufferizationOptions();
2611534177fSMatthias Springer     options.opFilter.allowDialect<BufferizationDialect>();
262ffdbecccSMatthias Springer 
263ffdbecccSMatthias Springer     if (failed(bufferizeOp(getOperation(), options)))
264ffdbecccSMatthias Springer       signalPassFailure();
265ffdbecccSMatthias Springer   }
266ffdbecccSMatthias Springer 
getDependentDialects__anon82ae79a50b11::BufferizationBufferizePass267ffdbecccSMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
268ffdbecccSMatthias Springer     registry
269ffdbecccSMatthias Springer         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
270ffdbecccSMatthias Springer   }
271ffdbecccSMatthias Springer };
272ffdbecccSMatthias Springer } // namespace
273ffdbecccSMatthias Springer 
createBufferizationBufferizePass()274ffdbecccSMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
275ffdbecccSMatthias Springer   return std::make_unique<BufferizationBufferizePass>();
276ffdbecccSMatthias Springer }
277ffdbecccSMatthias Springer 
createOneShotBufferizePass()278d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
279d2dacde5SMatthias Springer   return std::make_unique<OneShotBufferizePass>();
280d2dacde5SMatthias Springer }
281d2dacde5SMatthias Springer 
createOneShotBufferizePass(const OneShotBufferizationOptions & options)282d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
2839597b16aSMatthias Springer     const OneShotBufferizationOptions &options) {
284d2dacde5SMatthias Springer   return std::make_unique<OneShotBufferizePass>(options);
285d2dacde5SMatthias Springer }
286d2dacde5SMatthias Springer 
28758ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createFinalizingBufferizePass()288f89bb3c0SAlexander Belyaev mlir::bufferization::createFinalizingBufferizePass() {
289f89bb3c0SAlexander Belyaev   return std::make_unique<FinalizingBufferizePass>();
290f89bb3c0SAlexander Belyaev }
2917a1579acSMatthias Springer 
29249e37000SMatthias Springer //===----------------------------------------------------------------------===//
29349e37000SMatthias Springer // BufferizableOpInterface-based Bufferization
29449e37000SMatthias Springer //===----------------------------------------------------------------------===//
29549e37000SMatthias Springer 
isaTensor(Type t)2967a1579acSMatthias Springer static bool isaTensor(Type t) { return t.isa<TensorType>(); }
2977a1579acSMatthias Springer 
2987a1579acSMatthias Springer /// Return true if the given op has a tensor result or a tensor operand.
hasTensorSemantics(Operation * op)2997a1579acSMatthias Springer static bool hasTensorSemantics(Operation *op) {
30070777d96SMatthias Springer   if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
30170777d96SMatthias Springer     bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
30270777d96SMatthias Springer     bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
30370777d96SMatthias Springer     return hasTensorArg || hasTensorResult;
30470777d96SMatthias Springer   }
30570777d96SMatthias Springer 
3067a1579acSMatthias Springer   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
3077a1579acSMatthias Springer   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
3087a1579acSMatthias Springer   return hasTensorResult || hasTensorOperand;
3097a1579acSMatthias Springer }
3107a1579acSMatthias Springer 
311d820acddSMatthias Springer namespace {
312d820acddSMatthias Springer /// A rewriter that keeps track of extra information during bufferization.
313d820acddSMatthias Springer class BufferizationRewriter : public IRRewriter {
314d820acddSMatthias Springer public:
BufferizationRewriter(MLIRContext * ctx,DenseSet<Operation * > & erasedOps,DenseSet<Operation * > & toMemrefOps,SmallVector<Operation * > & worklist,const BufferizationOptions & options,const OpFilter * opFilter)315d820acddSMatthias Springer   BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
316d820acddSMatthias Springer                         DenseSet<Operation *> &toMemrefOps,
317b3ebe3beSMatthias Springer                         SmallVector<Operation *> &worklist,
3182f0a634cSMatthias Springer                         const BufferizationOptions &options,
3192f0a634cSMatthias Springer                         const OpFilter *opFilter)
320d820acddSMatthias Springer       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
321b3ebe3beSMatthias Springer         worklist(worklist), analysisState(options), opFilter(opFilter) {}
322d820acddSMatthias Springer 
323d820acddSMatthias Springer protected:
notifyOperationRemoved(Operation * op)324d820acddSMatthias Springer   void notifyOperationRemoved(Operation *op) override {
325d820acddSMatthias Springer     IRRewriter::notifyOperationRemoved(op);
326d820acddSMatthias Springer     erasedOps.insert(op);
3279785eb1bSMatthias Springer     // Erase if present.
3289785eb1bSMatthias Springer     toMemrefOps.erase(op);
329d820acddSMatthias Springer   }
330d820acddSMatthias Springer 
notifyOperationInserted(Operation * op)331d820acddSMatthias Springer   void notifyOperationInserted(Operation *op) override {
332d820acddSMatthias Springer     IRRewriter::notifyOperationInserted(op);
333b3ebe3beSMatthias Springer     erasedOps.erase(op);
334d820acddSMatthias Springer 
335d820acddSMatthias Springer     // Keep track of to_memref ops.
336d820acddSMatthias Springer     if (isa<ToMemrefOp>(op)) {
337d820acddSMatthias Springer       toMemrefOps.insert(op);
338d820acddSMatthias Springer       return;
339d820acddSMatthias Springer     }
340d820acddSMatthias Springer 
341d820acddSMatthias Springer     // Skip to_tensor ops.
342d820acddSMatthias Springer     if (isa<ToTensorOp>(op))
343d820acddSMatthias Springer       return;
344d820acddSMatthias Springer 
3452f0a634cSMatthias Springer     // Skip non-tensor ops.
3462f0a634cSMatthias Springer     if (!hasTensorSemantics(op))
3472f0a634cSMatthias Springer       return;
3482f0a634cSMatthias Springer 
349b3ebe3beSMatthias Springer     // Skip ops that are not allowed to be bufferized.
350b3ebe3beSMatthias Springer     auto const &options = analysisState.getOptions();
3512f0a634cSMatthias Springer     if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
3522f0a634cSMatthias Springer       return;
3532f0a634cSMatthias Springer 
354b3ebe3beSMatthias Springer #ifndef NDEBUG
355b3ebe3beSMatthias Springer     // Read-only tensor ops may be created during bufferization. Ops that are
356b3ebe3beSMatthias Springer     // writing should not be created because such ops were never analyzed.
357b3ebe3beSMatthias Springer     // Bufferizing such ops could introduce a RaW conflict.
358b3ebe3beSMatthias Springer     for (OpOperand &operand : op->getOpOperands())
359b3ebe3beSMatthias Springer       if (operand.get().getType().isa<TensorType>())
360b3ebe3beSMatthias Springer         assert(!analysisState.bufferizesToMemoryWrite(operand) &&
361b3ebe3beSMatthias Springer                "creating tensor ops that bufferize to a memory write is not "
362b3ebe3beSMatthias Springer                "allowed during bufferization");
363b3ebe3beSMatthias Springer #endif // NDEBUG
364b3ebe3beSMatthias Springer 
365b3ebe3beSMatthias Springer     // Add op to worklist.
366b3ebe3beSMatthias Springer     worklist.push_back(op);
367d820acddSMatthias Springer   }
368d820acddSMatthias Springer 
369d820acddSMatthias Springer private:
370d820acddSMatthias Springer   /// A set of all erased ops.
371d820acddSMatthias Springer   DenseSet<Operation *> &erasedOps;
372d820acddSMatthias Springer 
373d820acddSMatthias Springer   /// A set of all to_memref ops.
374d820acddSMatthias Springer   DenseSet<Operation *> &toMemrefOps;
375d820acddSMatthias Springer 
376b3ebe3beSMatthias Springer   /// The worklist of ops to be bufferized.
377b3ebe3beSMatthias Springer   SmallVector<Operation *> &worklist;
3782f0a634cSMatthias Springer 
379b3ebe3beSMatthias Springer   /// The analysis state. Used for debug assertions and access to the
380b3ebe3beSMatthias Springer   /// bufferization options.
381b3ebe3beSMatthias Springer   const AnalysisState analysisState;
382b3ebe3beSMatthias Springer 
383b3ebe3beSMatthias Springer   /// An extra op filter for bufferization.
3842f0a634cSMatthias Springer   const OpFilter *opFilter;
385d820acddSMatthias Springer };
386d820acddSMatthias Springer } // namespace
387d820acddSMatthias Springer 
bufferizeOp(Operation * op,const BufferizationOptions & options,bool copyBeforeWrite,const OpFilter * opFilter)3882f0a634cSMatthias Springer LogicalResult bufferization::bufferizeOp(Operation *op,
389b3ebe3beSMatthias Springer                                          const BufferizationOptions &options,
390b3ebe3beSMatthias Springer                                          bool copyBeforeWrite,
3912f0a634cSMatthias Springer                                          const OpFilter *opFilter) {
392b3ebe3beSMatthias Springer   if (copyBeforeWrite) {
393b3ebe3beSMatthias Springer     AnalysisState state(options);
394b3ebe3beSMatthias Springer     if (failed(insertTensorCopies(op, state)))
395b3ebe3beSMatthias Springer       return failure();
396b3ebe3beSMatthias Springer   }
397b3ebe3beSMatthias Springer 
398d820acddSMatthias Springer   // Keep track of to_memref ops.
399d820acddSMatthias Springer   DenseSet<Operation *> toMemrefOps;
400d820acddSMatthias Springer   op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
401d820acddSMatthias Springer 
402d820acddSMatthias Springer   // Gather all bufferizable ops in top-to-bottom order.
40376b16010SMatthias Springer   //
404d820acddSMatthias Springer   // We should ideally know the exact memref type of all operands when
405d820acddSMatthias Springer   // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
406f287da8aSMatthias Springer   // Otherwise, we have to use a memref type with a fully dynamic layout map to
407f287da8aSMatthias Springer   // avoid copies. We are currently missing patterns for layout maps to
408f287da8aSMatthias Springer   // canonicalize away (or canonicalize to more precise layouts).
409ba9d886dSMatthias Springer   //
410ba9d886dSMatthias Springer   // FuncOps must be bufferized before their bodies, so add them to the worklist
411ba9d886dSMatthias Springer   // first.
412d820acddSMatthias Springer   SmallVector<Operation *> worklist;
413ba9d886dSMatthias Springer   op->walk([&](func::FuncOp funcOp) {
414ba9d886dSMatthias Springer     if (hasTensorSemantics(funcOp))
415ba9d886dSMatthias Springer       worklist.push_back(funcOp);
416ba9d886dSMatthias Springer   });
417ba9d886dSMatthias Springer   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
418ba9d886dSMatthias Springer     if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
419d820acddSMatthias Springer       worklist.push_back(op);
420d820acddSMatthias Springer   });
4216fc753adSMatthias Springer 
422d820acddSMatthias Springer   // Keep track of all erased ops.
423d820acddSMatthias Springer   DenseSet<Operation *> erasedOps;
4247a1579acSMatthias Springer 
425d820acddSMatthias Springer   // Bufferize all ops.
426d820acddSMatthias Springer   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
427b3ebe3beSMatthias Springer                                  worklist, options, opFilter);
428d820acddSMatthias Springer   for (unsigned i = 0; i < worklist.size(); ++i) {
429d820acddSMatthias Springer     Operation *op = worklist[i];
430d820acddSMatthias Springer     // Skip ops that were erased.
431d820acddSMatthias Springer     if (erasedOps.contains(op))
432d820acddSMatthias Springer       continue;
4339785eb1bSMatthias Springer     // Skip ops that are not bufferizable or not allowed.
4349785eb1bSMatthias Springer     auto bufferizableOp = options.dynCastBufferizableOp(op);
435d820acddSMatthias Springer     if (!bufferizableOp)
436d820acddSMatthias Springer       continue;
4372f0a634cSMatthias Springer     if (opFilter && !opFilter->isOpAllowed(op))
4382f0a634cSMatthias Springer       continue;
4399785eb1bSMatthias Springer     // Skip ops that no longer have tensor semantics.
4409785eb1bSMatthias Springer     if (!hasTensorSemantics(op))
441d820acddSMatthias Springer       continue;
442d820acddSMatthias Springer     // Bufferize the op.
443d820acddSMatthias Springer     rewriter.setInsertionPoint(op);
444b55d55ecSMatthias Springer     if (failed(bufferizableOp.bufferize(rewriter, options)))
4450b293bf0SMatthias Springer       return op->emitError("failed to bufferize op");
446d820acddSMatthias Springer   }
447d820acddSMatthias Springer 
448d820acddSMatthias Springer   // Fold all to_memref(to_tensor(x)) pairs.
449d820acddSMatthias Springer   for (Operation *op : toMemrefOps) {
450d820acddSMatthias Springer     rewriter.setInsertionPoint(op);
451d820acddSMatthias Springer     (void)bufferization::foldToMemrefToTensorPair(rewriter,
452d820acddSMatthias Springer                                                   cast<ToMemrefOp>(op));
453d820acddSMatthias Springer   }
454d820acddSMatthias Springer 
455d820acddSMatthias Springer   /// Check the result of bufferization. Return an error if an op was not
456d820acddSMatthias Springer   /// bufferized, unless partial bufferization is allowed.
457b55d55ecSMatthias Springer   if (options.allowUnknownOps)
458d820acddSMatthias Springer     return success();
459d820acddSMatthias Springer 
460d820acddSMatthias Springer   for (Operation *op : worklist) {
461d820acddSMatthias Springer     // Skip ops that are entirely gone.
462d820acddSMatthias Springer     if (erasedOps.contains(op))
463d820acddSMatthias Springer       continue;
464d820acddSMatthias Springer     // Ops that no longer have tensor semantics (because they were updated
465d820acddSMatthias Springer     // in-place) are allowed.
466d820acddSMatthias Springer     if (!hasTensorSemantics(op))
467d820acddSMatthias Springer       continue;
468d820acddSMatthias Springer     // Continue ops that are not allowed.
469d820acddSMatthias Springer     if (!options.isOpAllowed(op))
470d820acddSMatthias Springer       continue;
4712f0a634cSMatthias Springer     if (opFilter && !opFilter->isOpAllowed(op))
4722f0a634cSMatthias Springer       continue;
473d820acddSMatthias Springer     // Ops without any uses and no side effects will fold away.
474d820acddSMatthias Springer     if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
475d820acddSMatthias Springer       continue;
476b3ebe3beSMatthias Springer     // ToTensorOps/ToMemrefOps are allowed in the output.
477b3ebe3beSMatthias Springer     if (isa<ToTensorOp, ToMemrefOp>(op))
478b3ebe3beSMatthias Springer       continue;
479d820acddSMatthias Springer     return op->emitError("op was not bufferized");
480d820acddSMatthias Springer   }
48105e0495fSMatthias Springer 
48205e0495fSMatthias Springer   return success();
4837a1579acSMatthias Springer }
484daf18108SMatthias Springer 
getPartialBufferizationOptions()485cdb7675cSMatthias Springer BufferizationOptions bufferization::getPartialBufferizationOptions() {
486cdb7675cSMatthias Springer   BufferizationOptions options;
487cdb7675cSMatthias Springer   options.allowUnknownOps = true;
488cdb7675cSMatthias Springer   options.createDeallocs = false;
489b3ebe3beSMatthias Springer   options.enforceAliasingInvariants = false;
490606f7c8fSMatthias Springer   options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
491606f7c8fSMatthias Springer                                       const BufferizationOptions &options) {
492606f7c8fSMatthias Springer     return getMemRefTypeWithStaticIdentityLayout(
493606f7c8fSMatthias Springer         value.getType().cast<TensorType>(), memorySpace);
494606f7c8fSMatthias Springer   };
495b3ebe3beSMatthias Springer   options.opFilter.allowDialect<BufferizationDialect>();
496daf18108SMatthias Springer   return options;
497daf18108SMatthias Springer }
498