1*461dafd2SMatthias Springer //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2*461dafd2SMatthias Springer // 3*461dafd2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*461dafd2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5*461dafd2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*461dafd2SMatthias Springer // 7*461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 8*461dafd2SMatthias Springer 9*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10*461dafd2SMatthias Springer 11*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14*461dafd2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 15*461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDL.h" 16*461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDLTypes.h" 17*461dafd2SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18*461dafd2SMatthias Springer 19*461dafd2SMatthias Springer using namespace mlir; 20*461dafd2SMatthias Springer using namespace mlir::bufferization; 21*461dafd2SMatthias Springer using namespace mlir::transform; 22*461dafd2SMatthias Springer 23*461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 24*461dafd2SMatthias Springer // OneShotBufferizeOp 25*461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 26*461dafd2SMatthias Springer 27*461dafd2SMatthias Springer LogicalResult 28*461dafd2SMatthias Springer transform::OneShotBufferizeOp::apply(TransformResults &transformResults, 29*461dafd2SMatthias Springer TransformState &state) { 30*461dafd2SMatthias Springer OneShotBufferizationOptions options; 31*461dafd2SMatthias Springer options.allowReturnAllocs = getAllowReturnAllocs(); 32*461dafd2SMatthias Springer options.allowUnknownOps = getAllowUnknownOps(); 33*461dafd2SMatthias Springer options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 34*461dafd2SMatthias Springer options.createDeallocs = getCreateDeallocs(); 35*461dafd2SMatthias Springer options.testAnalysisOnly = getTestAnalysisOnly(); 36*461dafd2SMatthias Springer options.printConflicts = getPrintConflicts(); 37*461dafd2SMatthias Springer 38*461dafd2SMatthias Springer ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 39*461dafd2SMatthias Springer for (Operation *target : payloadOps) { 40*461dafd2SMatthias Springer auto moduleOp = dyn_cast<ModuleOp>(target); 41*461dafd2SMatthias Springer if (getTargetIsModule() && !moduleOp) 42*461dafd2SMatthias Springer return emitError("expected ModuleOp target"); 43*461dafd2SMatthias Springer if (options.bufferizeFunctionBoundaries) { 44*461dafd2SMatthias Springer if (!moduleOp) 45*461dafd2SMatthias Springer return emitError("expected ModuleOp target"); 46*461dafd2SMatthias Springer if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 47*461dafd2SMatthias Springer return emitError("bufferization failed"); 48*461dafd2SMatthias Springer } else { 49*461dafd2SMatthias Springer if (failed(bufferization::runOneShotBufferize(target, options))) 50*461dafd2SMatthias Springer return emitError("bufferization failed"); 51*461dafd2SMatthias Springer } 52*461dafd2SMatthias Springer } 53*461dafd2SMatthias Springer 54*461dafd2SMatthias Springer return success(); 55*461dafd2SMatthias Springer } 56*461dafd2SMatthias Springer 57*461dafd2SMatthias Springer void transform::OneShotBufferizeOp::getEffects( 58*461dafd2SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 59*461dafd2SMatthias Springer effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 60*461dafd2SMatthias Springer TransformMappingResource::get()); 61*461dafd2SMatthias Springer 62*461dafd2SMatthias Springer // Handles that are not modules are not longer usable. 63*461dafd2SMatthias Springer if (!getTargetIsModule()) 64*461dafd2SMatthias Springer effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 65*461dafd2SMatthias Springer TransformMappingResource::get()); 66*461dafd2SMatthias Springer } 67*461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 68*461dafd2SMatthias Springer // Transform op registration 69*461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 70*461dafd2SMatthias Springer 71*461dafd2SMatthias Springer namespace { 72*461dafd2SMatthias Springer /// Registers new ops and declares PDL as dependent dialect since the additional 73*461dafd2SMatthias Springer /// ops are using PDL types for operands and results. 74*461dafd2SMatthias Springer class BufferizationTransformDialectExtension 75*461dafd2SMatthias Springer : public transform::TransformDialectExtension< 76*461dafd2SMatthias Springer BufferizationTransformDialectExtension> { 77*461dafd2SMatthias Springer public: 78*461dafd2SMatthias Springer BufferizationTransformDialectExtension() { 79*461dafd2SMatthias Springer declareDependentDialect<bufferization::BufferizationDialect>(); 80*461dafd2SMatthias Springer declareDependentDialect<pdl::PDLDialect>(); 81*461dafd2SMatthias Springer declareDependentDialect<memref::MemRefDialect>(); 82*461dafd2SMatthias Springer registerTransformOps< 83*461dafd2SMatthias Springer #define GET_OP_LIST 84*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 85*461dafd2SMatthias Springer >(); 86*461dafd2SMatthias Springer } 87*461dafd2SMatthias Springer }; 88*461dafd2SMatthias Springer } // namespace 89*461dafd2SMatthias Springer 90*461dafd2SMatthias Springer #define GET_OP_CLASSES 91*461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 92*461dafd2SMatthias Springer 93*461dafd2SMatthias Springer void mlir::bufferization::registerTransformDialectExtension( 94*461dafd2SMatthias Springer DialectRegistry ®istry) { 95*461dafd2SMatthias Springer registry.addExtensions<BufferizationTransformDialectExtension>(); 96*461dafd2SMatthias Springer } 97