1 //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 19 using namespace mlir; 20 using namespace mlir::bufferization; 21 using namespace mlir::transform; 22 23 //===----------------------------------------------------------------------===// 24 // OneShotBufferizeOp 25 //===----------------------------------------------------------------------===// 26 27 DiagnosedSilenceableFailure 28 transform::OneShotBufferizeOp::apply(TransformResults &transformResults, 29 TransformState &state) { 30 OneShotBufferizationOptions options; 31 options.allowReturnAllocs = getAllowReturnAllocs(); 32 options.allowUnknownOps = getAllowUnknownOps(); 33 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 34 options.createDeallocs = getCreateDeallocs(); 35 options.testAnalysisOnly = getTestAnalysisOnly(); 36 options.printConflicts = getPrintConflicts(); 37 38 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 39 for (Operation *target : payloadOps) { 40 auto moduleOp = dyn_cast<ModuleOp>(target); 41 if (getTargetIsModule() && !moduleOp) 42 return emitSilenceableError() << "expected ModuleOp target"; 43 if (options.bufferizeFunctionBoundaries) { 44 if (!moduleOp) 45 return emitSilenceableError() << "expected ModuleOp target"; 46 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 47 return emitSilenceableError() << "bufferization failed"; 48 } else { 49 if (failed(bufferization::runOneShotBufferize(target, options))) 50 return emitSilenceableError() << "bufferization failed"; 51 } 52 } 53 54 return DiagnosedSilenceableFailure::success(); 55 } 56 57 void transform::OneShotBufferizeOp::getEffects( 58 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 59 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 60 TransformMappingResource::get()); 61 62 // Handles that are not modules are not longer usable. 63 if (!getTargetIsModule()) 64 effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 65 TransformMappingResource::get()); 66 } 67 //===----------------------------------------------------------------------===// 68 // Transform op registration 69 //===----------------------------------------------------------------------===// 70 71 namespace { 72 /// Registers new ops and declares PDL as dependent dialect since the additional 73 /// ops are using PDL types for operands and results. 74 class BufferizationTransformDialectExtension 75 : public transform::TransformDialectExtension< 76 BufferizationTransformDialectExtension> { 77 public: 78 BufferizationTransformDialectExtension() { 79 declareDependentDialect<bufferization::BufferizationDialect>(); 80 declareDependentDialect<pdl::PDLDialect>(); 81 declareDependentDialect<memref::MemRefDialect>(); 82 registerTransformOps< 83 #define GET_OP_LIST 84 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 85 >(); 86 } 87 }; 88 } // namespace 89 90 #define GET_OP_CLASSES 91 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 92 93 void mlir::bufferization::registerTransformDialectExtension( 94 DialectRegistry ®istry) { 95 registry.addExtensions<BufferizationTransformDialectExtension>(); 96 } 97