1a54f4eaeSMogball //===- Bufferize.cpp - Bufferization for Arithmetic ops ---------*- C++ -*-===// 2a54f4eaeSMogball // 3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information. 5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a54f4eaeSMogball // 7a54f4eaeSMogball //===----------------------------------------------------------------------===// 8a54f4eaeSMogball 9a54f4eaeSMogball #include "PassDetail.h" 10f89bb3c0SAlexander Belyaev 11eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 14075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1557470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 17a54f4eaeSMogball #include "mlir/Dialect/MemRef/IR/MemRef.h" 18a54f4eaeSMogball 19a54f4eaeSMogball using namespace mlir; 20075e3fddSMatthias Springer using namespace bufferization; 21a54f4eaeSMogball 22a54f4eaeSMogball namespace { 23a54f4eaeSMogball /// Pass to bufferize Arithmetic ops. 24a54f4eaeSMogball struct ArithmeticBufferizePass 25a54f4eaeSMogball : public ArithmeticBufferizeBase<ArithmeticBufferizePass> { ArithmeticBufferizePass__anonf27b07370111::ArithmeticBufferizePass26ab47418dSMatthias Springer ArithmeticBufferizePass(uint64_t alignment = 0, bool constantOpOnly = false) 27ab47418dSMatthias Springer : ArithmeticBufferizeBase<ArithmeticBufferizePass>(), 28ab47418dSMatthias Springer constantOpOnly(constantOpOnly) { 29ab47418dSMatthias Springer this->alignment = alignment; 30ab47418dSMatthias Springer } 31ab47418dSMatthias Springer runOnOperation__anonf27b07370111::ArithmeticBufferizePass3241574554SRiver Riddle void runOnOperation() override { 33cdb7675cSMatthias Springer BufferizationOptions options = getPartialBufferizationOptions(); 34ab47418dSMatthias Springer if (constantOpOnly) { 35*1534177fSMatthias Springer options.opFilter.allowOperation<arith::ConstantOp>(); 36ab47418dSMatthias Springer } else { 37*1534177fSMatthias Springer options.opFilter.allowDialect<arith::ArithmeticDialect>(); 38ab47418dSMatthias Springer } 39cdb7675cSMatthias Springer options.bufferAlignment = alignment; 40a54f4eaeSMogball 41cdb7675cSMatthias Springer if (failed(bufferizeOp(getOperation(), options))) 42a54f4eaeSMogball signalPassFailure(); 43a54f4eaeSMogball } 44a54f4eaeSMogball getDependentDialects__anonf27b07370111::ArithmeticBufferizePass45075e3fddSMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override { 46075e3fddSMatthias Springer registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect, 47075e3fddSMatthias Springer arith::ArithmeticDialect>(); 48075e3fddSMatthias Springer arith::registerBufferizableOpInterfaceExternalModels(registry); 49a54f4eaeSMogball } 50ab47418dSMatthias Springer 51ab47418dSMatthias Springer private: 52ab47418dSMatthias Springer bool constantOpOnly; 53075e3fddSMatthias Springer }; 54075e3fddSMatthias Springer } // namespace 55a54f4eaeSMogball createArithmeticBufferizePass()56a54f4eaeSMogballstd::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() { 57a54f4eaeSMogball return std::make_unique<ArithmeticBufferizePass>(); 58a54f4eaeSMogball } 59ab47418dSMatthias Springer 60ab47418dSMatthias Springer std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment)61ab47418dSMatthias Springermlir::arith::createConstantBufferizePass(uint64_t alignment) { 62ab47418dSMatthias Springer return std::make_unique<ArithmeticBufferizePass>(alignment, 63ab47418dSMatthias Springer /*constantOpOnly=*/true); 64ab47418dSMatthias Springer } 65