//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 using namespace mlir::transform;
22 
23 //===----------------------------------------------------------------------===//
24 // OneShotBufferizeOp
25 //===----------------------------------------------------------------------===//
26 
27 DiagnosedSilenceableFailure
28 transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
29                                      TransformState &state) {
30   OneShotBufferizationOptions options;
31   options.allowReturnAllocs = getAllowReturnAllocs();
32   options.allowUnknownOps = getAllowUnknownOps();
33   options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
34   options.createDeallocs = getCreateDeallocs();
35   options.testAnalysisOnly = getTestAnalysisOnly();
36   options.printConflicts = getPrintConflicts();
37 
38   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
39   for (Operation *target : payloadOps) {
40     auto moduleOp = dyn_cast<ModuleOp>(target);
41     if (getTargetIsModule() && !moduleOp)
42       return emitSilenceableError() << "expected ModuleOp target";
43     if (options.bufferizeFunctionBoundaries) {
44       if (!moduleOp)
45         return emitSilenceableError() << "expected ModuleOp target";
46       if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
47         return emitSilenceableError() << "bufferization failed";
48     } else {
49       if (failed(bufferization::runOneShotBufferize(target, options)))
50         return emitSilenceableError() << "bufferization failed";
51     }
52   }
53 
54   return DiagnosedSilenceableFailure::success();
55 }
56 
57 void transform::OneShotBufferizeOp::getEffects(
58     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
59   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
60                        TransformMappingResource::get());
61 
62   // Handles that are not modules are not longer usable.
63   if (!getTargetIsModule())
64     effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
65                          TransformMappingResource::get());
66 }
67 //===----------------------------------------------------------------------===//
68 // Transform op registration
69 //===----------------------------------------------------------------------===//
70 
namespace {
/// Transform dialect extension that registers the bufferization transform ops.
/// It declares PDL as a dependent dialect because the additional ops use PDL
/// types for operands and results. It also declares the bufferization and
/// memref dialects as generated dialects.
class BufferizationTransformDialectExtension
    : public transform::TransformDialectExtension<
          BufferizationTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares the dialect dependencies of this extension and registers every
  /// op from the tablegen-generated op list.
  void init() {
    declareDependentDialect<pdl::PDLDialect>();

    // NOTE(review): presumably these are the dialects whose ops the transforms
    // may create in payload IR (e.g. during bufferization) — confirm.
    declareGeneratedDialect<bufferization::BufferizationDialect>();
    declareGeneratedDialect<memref::MemRefDialect>();

    // The GET_OP_LIST section of the generated .cpp.inc expands to the list of
    // op classes used as template arguments here.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
        >();
  }
};
} // namespace
93 
94 #define GET_OP_CLASSES
95 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
96 
97 void mlir::bufferization::registerTransformDialectExtension(
98     DialectRegistry &registry) {
99   registry.addExtensions<BufferizationTransformDialectExtension>();
100 }
101