//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2461dafd2SMatthias Springer //
3461dafd2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4461dafd2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5461dafd2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6461dafd2SMatthias Springer //
7461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
8461dafd2SMatthias Springer 
9461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10461dafd2SMatthias Springer 
11461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14461dafd2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
15461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDL.h"
16461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDLTypes.h"
17461dafd2SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18461dafd2SMatthias Springer 
19461dafd2SMatthias Springer using namespace mlir;
20461dafd2SMatthias Springer using namespace mlir::bufferization;
21461dafd2SMatthias Springer using namespace mlir::transform;
22461dafd2SMatthias Springer 
23461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
24461dafd2SMatthias Springer // OneShotBufferizeOp
25461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
26461dafd2SMatthias Springer 
271d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(TransformResults & transformResults,TransformState & state)28461dafd2SMatthias Springer transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
29461dafd2SMatthias Springer                                      TransformState &state) {
30461dafd2SMatthias Springer   OneShotBufferizationOptions options;
31461dafd2SMatthias Springer   options.allowReturnAllocs = getAllowReturnAllocs();
32461dafd2SMatthias Springer   options.allowUnknownOps = getAllowUnknownOps();
33461dafd2SMatthias Springer   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
34461dafd2SMatthias Springer   options.createDeallocs = getCreateDeallocs();
35461dafd2SMatthias Springer   options.testAnalysisOnly = getTestAnalysisOnly();
36461dafd2SMatthias Springer   options.printConflicts = getPrintConflicts();
37461dafd2SMatthias Springer 
38461dafd2SMatthias Springer   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
39461dafd2SMatthias Springer   for (Operation *target : payloadOps) {
40461dafd2SMatthias Springer     auto moduleOp = dyn_cast<ModuleOp>(target);
41461dafd2SMatthias Springer     if (getTargetIsModule() && !moduleOp)
421d45282aSAlex Zinenko       return emitSilenceableError() << "expected ModuleOp target";
43461dafd2SMatthias Springer     if (options.bufferizeFunctionBoundaries) {
44461dafd2SMatthias Springer       if (!moduleOp)
451d45282aSAlex Zinenko         return emitSilenceableError() << "expected ModuleOp target";
46461dafd2SMatthias Springer       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
471d45282aSAlex Zinenko         return emitSilenceableError() << "bufferization failed";
48461dafd2SMatthias Springer     } else {
49461dafd2SMatthias Springer       if (failed(bufferization::runOneShotBufferize(target, options)))
501d45282aSAlex Zinenko         return emitSilenceableError() << "bufferization failed";
51461dafd2SMatthias Springer     }
52461dafd2SMatthias Springer   }
53461dafd2SMatthias Springer 
541d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
55461dafd2SMatthias Springer }
56461dafd2SMatthias Springer 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)57461dafd2SMatthias Springer void transform::OneShotBufferizeOp::getEffects(
58461dafd2SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
59461dafd2SMatthias Springer   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
60461dafd2SMatthias Springer                        TransformMappingResource::get());
61461dafd2SMatthias Springer 
62461dafd2SMatthias Springer   // Handles that are not modules are not longer usable.
63461dafd2SMatthias Springer   if (!getTargetIsModule())
64461dafd2SMatthias Springer     effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
65461dafd2SMatthias Springer                          TransformMappingResource::get());
66461dafd2SMatthias Springer }
67461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
68461dafd2SMatthias Springer // Transform op registration
69461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
70461dafd2SMatthias Springer 
namespace {
/// Transform dialect extension that registers the bufferization transform ops.
///
/// - PDL is declared as a dependent dialect because the additional ops use
///   PDL types for operands and results.
/// - The Bufferization and MemRef dialects are declared as generated dialects.
///   This is presumably because applying One-Shot Bufferize creates ops from
///   them. NOTE(review): confirm the list stays complete if the ops change.
class BufferizationTransformDialectExtension
    : public transform::TransformDialectExtension<
          BufferizationTransformDialectExtension> {
public:
  // Reuse the constructors of the TransformDialectExtension base.
  using Base::Base;

  /// Declares dialect dependencies and registers all ops of this extension.
  void init() {
    declareDependentDialect<pdl::PDLDialect>();

    declareGeneratedDialect<bufferization::BufferizationDialect>();
    declareGeneratedDialect<memref::MemRefDialect>();

    // The generated op list expands to the comma-separated op class names.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
        >();
  }
};
} // namespace
93461dafd2SMatthias Springer 
94461dafd2SMatthias Springer #define GET_OP_CLASSES
95461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
96461dafd2SMatthias Springer 
/// Adds the bufferization transform dialect extension to `registry`, making
/// the bufferization transform ops available once the Transform dialect is
/// loaded from that registry.
void mlir::bufferization::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<BufferizationTransformDialectExtension>();
}
101