//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2*461dafd2SMatthias Springer //
3*461dafd2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*461dafd2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5*461dafd2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*461dafd2SMatthias Springer //
7*461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
8*461dafd2SMatthias Springer 
9*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10*461dafd2SMatthias Springer 
11*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14*461dafd2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
15*461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDL.h"
16*461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDLTypes.h"
17*461dafd2SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18*461dafd2SMatthias Springer 
19*461dafd2SMatthias Springer using namespace mlir;
20*461dafd2SMatthias Springer using namespace mlir::bufferization;
21*461dafd2SMatthias Springer using namespace mlir::transform;
22*461dafd2SMatthias Springer 
23*461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
24*461dafd2SMatthias Springer // OneShotBufferizeOp
25*461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
26*461dafd2SMatthias Springer 
27*461dafd2SMatthias Springer LogicalResult
28*461dafd2SMatthias Springer transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
29*461dafd2SMatthias Springer                                      TransformState &state) {
30*461dafd2SMatthias Springer   OneShotBufferizationOptions options;
31*461dafd2SMatthias Springer   options.allowReturnAllocs = getAllowReturnAllocs();
32*461dafd2SMatthias Springer   options.allowUnknownOps = getAllowUnknownOps();
33*461dafd2SMatthias Springer   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
34*461dafd2SMatthias Springer   options.createDeallocs = getCreateDeallocs();
35*461dafd2SMatthias Springer   options.testAnalysisOnly = getTestAnalysisOnly();
36*461dafd2SMatthias Springer   options.printConflicts = getPrintConflicts();
37*461dafd2SMatthias Springer 
38*461dafd2SMatthias Springer   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
39*461dafd2SMatthias Springer   for (Operation *target : payloadOps) {
40*461dafd2SMatthias Springer     auto moduleOp = dyn_cast<ModuleOp>(target);
41*461dafd2SMatthias Springer     if (getTargetIsModule() && !moduleOp)
42*461dafd2SMatthias Springer       return emitError("expected ModuleOp target");
43*461dafd2SMatthias Springer     if (options.bufferizeFunctionBoundaries) {
44*461dafd2SMatthias Springer       if (!moduleOp)
45*461dafd2SMatthias Springer         return emitError("expected ModuleOp target");
46*461dafd2SMatthias Springer       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
47*461dafd2SMatthias Springer         return emitError("bufferization failed");
48*461dafd2SMatthias Springer     } else {
49*461dafd2SMatthias Springer       if (failed(bufferization::runOneShotBufferize(target, options)))
50*461dafd2SMatthias Springer         return emitError("bufferization failed");
51*461dafd2SMatthias Springer     }
52*461dafd2SMatthias Springer   }
53*461dafd2SMatthias Springer 
54*461dafd2SMatthias Springer   return success();
55*461dafd2SMatthias Springer }
56*461dafd2SMatthias Springer 
57*461dafd2SMatthias Springer void transform::OneShotBufferizeOp::getEffects(
58*461dafd2SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
59*461dafd2SMatthias Springer   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
60*461dafd2SMatthias Springer                        TransformMappingResource::get());
61*461dafd2SMatthias Springer 
62*461dafd2SMatthias Springer   // Handles that are not modules are not longer usable.
63*461dafd2SMatthias Springer   if (!getTargetIsModule())
64*461dafd2SMatthias Springer     effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
65*461dafd2SMatthias Springer                          TransformMappingResource::get());
66*461dafd2SMatthias Springer }
67*461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
68*461dafd2SMatthias Springer // Transform op registration
69*461dafd2SMatthias Springer //===----------------------------------------------------------------------===//
70*461dafd2SMatthias Springer 
71*461dafd2SMatthias Springer namespace {
72*461dafd2SMatthias Springer /// Registers new ops and declares PDL as dependent dialect since the additional
73*461dafd2SMatthias Springer /// ops are using PDL types for operands and results.
74*461dafd2SMatthias Springer class BufferizationTransformDialectExtension
75*461dafd2SMatthias Springer     : public transform::TransformDialectExtension<
76*461dafd2SMatthias Springer           BufferizationTransformDialectExtension> {
77*461dafd2SMatthias Springer public:
78*461dafd2SMatthias Springer   BufferizationTransformDialectExtension() {
79*461dafd2SMatthias Springer     declareDependentDialect<bufferization::BufferizationDialect>();
80*461dafd2SMatthias Springer     declareDependentDialect<pdl::PDLDialect>();
81*461dafd2SMatthias Springer     declareDependentDialect<memref::MemRefDialect>();
82*461dafd2SMatthias Springer     registerTransformOps<
83*461dafd2SMatthias Springer #define GET_OP_LIST
84*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
85*461dafd2SMatthias Springer         >();
86*461dafd2SMatthias Springer   }
87*461dafd2SMatthias Springer };
88*461dafd2SMatthias Springer } // namespace
89*461dafd2SMatthias Springer 
90*461dafd2SMatthias Springer #define GET_OP_CLASSES
91*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
92*461dafd2SMatthias Springer 
93*461dafd2SMatthias Springer void mlir::bufferization::registerTransformDialectExtension(
94*461dafd2SMatthias Springer     DialectRegistry &registry) {
95*461dafd2SMatthias Springer   registry.addExtensions<BufferizationTransformDialectExtension>();
96*461dafd2SMatthias Springer }
97