1461dafd2SMatthias Springer //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2461dafd2SMatthias Springer // 3461dafd2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4461dafd2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5461dafd2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6461dafd2SMatthias Springer // 7461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 8461dafd2SMatthias Springer 9461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10461dafd2SMatthias Springer 11461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14461dafd2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 15461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDL.h" 16461dafd2SMatthias Springer #include "mlir/Dialect/PDL/IR/PDLTypes.h" 17461dafd2SMatthias Springer #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18461dafd2SMatthias Springer 19461dafd2SMatthias Springer using namespace mlir; 20461dafd2SMatthias Springer using namespace mlir::bufferization; 21461dafd2SMatthias Springer using namespace mlir::transform; 22461dafd2SMatthias Springer 23461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 24461dafd2SMatthias Springer // OneShotBufferizeOp 25461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 26461dafd2SMatthias Springer 271d45282aSAlex Zinenko DiagnosedSilenceableFailure apply(TransformResults & transformResults,TransformState & state)28461dafd2SMatthias Springertransform::OneShotBufferizeOp::apply(TransformResults &transformResults, 29461dafd2SMatthias Springer TransformState &state) { 30461dafd2SMatthias Springer OneShotBufferizationOptions options; 31461dafd2SMatthias Springer options.allowReturnAllocs = getAllowReturnAllocs(); 32461dafd2SMatthias Springer options.allowUnknownOps = getAllowUnknownOps(); 33461dafd2SMatthias Springer options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 34461dafd2SMatthias Springer options.createDeallocs = getCreateDeallocs(); 35461dafd2SMatthias Springer options.testAnalysisOnly = getTestAnalysisOnly(); 36461dafd2SMatthias Springer options.printConflicts = getPrintConflicts(); 37461dafd2SMatthias Springer 38461dafd2SMatthias Springer ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 39461dafd2SMatthias Springer for (Operation *target : payloadOps) { 40461dafd2SMatthias Springer auto moduleOp = dyn_cast<ModuleOp>(target); 41461dafd2SMatthias Springer if (getTargetIsModule() && !moduleOp) 421d45282aSAlex Zinenko return emitSilenceableError() << "expected ModuleOp target"; 43461dafd2SMatthias Springer if (options.bufferizeFunctionBoundaries) { 44461dafd2SMatthias Springer if (!moduleOp) 451d45282aSAlex Zinenko return emitSilenceableError() << "expected ModuleOp target"; 46461dafd2SMatthias Springer if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 471d45282aSAlex Zinenko return emitSilenceableError() << "bufferization failed"; 48461dafd2SMatthias Springer } else { 49461dafd2SMatthias Springer if (failed(bufferization::runOneShotBufferize(target, options))) 501d45282aSAlex Zinenko return emitSilenceableError() << "bufferization failed"; 51461dafd2SMatthias Springer } 52461dafd2SMatthias Springer } 53461dafd2SMatthias Springer 541d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success(); 55461dafd2SMatthias Springer } 56461dafd2SMatthias Springer getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)57461dafd2SMatthias Springervoid transform::OneShotBufferizeOp::getEffects( 58461dafd2SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 59461dafd2SMatthias Springer effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 60461dafd2SMatthias Springer TransformMappingResource::get()); 61461dafd2SMatthias Springer 62461dafd2SMatthias Springer // Handles that are not modules are not longer usable. 63461dafd2SMatthias Springer if (!getTargetIsModule()) 64461dafd2SMatthias Springer effects.emplace_back(MemoryEffects::Free::get(), getTarget(), 65461dafd2SMatthias Springer TransformMappingResource::get()); 66461dafd2SMatthias Springer } 67461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 68461dafd2SMatthias Springer // Transform op registration 69461dafd2SMatthias Springer //===----------------------------------------------------------------------===// 70461dafd2SMatthias Springer 71461dafd2SMatthias Springer namespace { 72461dafd2SMatthias Springer /// Registers new ops and declares PDL as dependent dialect since the additional 73461dafd2SMatthias Springer /// ops are using PDL types for operands and results. 74461dafd2SMatthias Springer class BufferizationTransformDialectExtension 75461dafd2SMatthias Springer : public transform::TransformDialectExtension< 76461dafd2SMatthias Springer BufferizationTransformDialectExtension> { 77461dafd2SMatthias Springer public: 78*333ee218SAlex Zinenko using Base::Base; 79*333ee218SAlex Zinenko init()80*333ee218SAlex Zinenko void init() { 81461dafd2SMatthias Springer declareDependentDialect<pdl::PDLDialect>(); 82*333ee218SAlex Zinenko 83*333ee218SAlex Zinenko declareGeneratedDialect<bufferization::BufferizationDialect>(); 84*333ee218SAlex Zinenko declareGeneratedDialect<memref::MemRefDialect>(); 85*333ee218SAlex Zinenko 86461dafd2SMatthias Springer registerTransformOps< 87461dafd2SMatthias Springer #define GET_OP_LIST 88461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 89461dafd2SMatthias Springer >(); 90461dafd2SMatthias Springer } 91461dafd2SMatthias Springer }; 92461dafd2SMatthias Springer } // namespace 93461dafd2SMatthias Springer 94461dafd2SMatthias Springer #define GET_OP_CLASSES 95461dafd2SMatthias Springer #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 96461dafd2SMatthias Springer registerTransformDialectExtension(DialectRegistry & registry)97461dafd2SMatthias Springervoid mlir::bufferization::registerTransformDialectExtension( 98461dafd2SMatthias Springer DialectRegistry ®istry) { 99461dafd2SMatthias Springer registry.addExtensions<BufferizationTransformDialectExtension>(); 100461dafd2SMatthias Springer } 101