1 //===- Bufferize.cpp - scf bufferize pass ---------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/SCF/Passes.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/SCF/Transforms.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::scf; 19 20 namespace { 21 struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> { 22 void runOnFunction() override { 23 auto func = getOperation(); 24 auto *context = &getContext(); 25 26 BufferizeTypeConverter typeConverter; 27 OwningRewritePatternList patterns; 28 ConversionTarget target(*context); 29 30 // TODO: Move this to BufferizeTypeConverter's constructor. 31 // 32 // This doesn't currently play well with "finalizing" bufferizations (ones 33 // that expect all materializations to be gone). In particular, there seems 34 // to at least be a double-free in the dialect conversion framework 35 // when this materialization gets inserted and then folded away because 36 // it is marked as illegal. 37 typeConverter.addArgumentMaterialization( 38 [](OpBuilder &builder, RankedTensorType type, ValueRange inputs, 39 Location loc) -> Value { 40 assert(inputs.size() == 1); 41 assert(inputs[0].getType().isa<BaseMemRefType>()); 42 return builder.create<TensorLoadOp>(loc, type, inputs[0]); 43 }); 44 45 populateBufferizeMaterializationLegality(target); 46 populateSCFStructuralTypeConversionsAndLegality(context, typeConverter, 47 patterns, target); 48 if (failed(applyPartialConversion(func, target, std::move(patterns)))) 49 return signalPassFailure(); 50 }; 51 }; 52 } // end anonymous namespace 53 54 std::unique_ptr<Pass> mlir::createSCFBufferizePass() { 55 return std::make_unique<SCFBufferizePass>(); 56 } 57