1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 13 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 14 #include "mlir/IR/Operation.h" 15 16 using namespace mlir; 17 using namespace mlir::bufferization; 18 19 //===----------------------------------------------------------------------===// 20 // BufferizeTypeConverter 21 //===----------------------------------------------------------------------===// 22 23 static Value materializeToTensor(OpBuilder &builder, TensorType type, 24 ValueRange inputs, Location loc) { 25 assert(inputs.size() == 1); 26 assert(inputs[0].getType().isa<BaseMemRefType>()); 27 return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]); 28 } 29 30 /// Registers conversions into BufferizeTypeConverter 31 BufferizeTypeConverter::BufferizeTypeConverter() { 32 // Keep all types unchanged. 33 addConversion([](Type type) { return type; }); 34 // Convert RankedTensorType to MemRefType. 35 addConversion([](RankedTensorType type) -> Type { 36 return MemRefType::get(type.getShape(), type.getElementType()); 37 }); 38 // Convert UnrankedTensorType to UnrankedMemRefType. 39 addConversion([](UnrankedTensorType type) -> Type { 40 return UnrankedMemRefType::get(type.getElementType(), 0); 41 }); 42 addArgumentMaterialization(materializeToTensor); 43 addSourceMaterialization(materializeToTensor); 44 addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, 45 ValueRange inputs, Location loc) -> Value { 46 assert(inputs.size() == 1); 47 assert(inputs[0].getType().isa<TensorType>()); 48 return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]); 49 }); 50 } 51 52 void mlir::bufferization::populateBufferizeMaterializationLegality( 53 ConversionTarget &target) { 54 target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>(); 55 } 56 57 namespace { 58 // In a finalizing bufferize conversion, we know that all tensors have been 59 // converted to memrefs, thus, this op becomes an identity. 60 class BufferizeToTensorOp 61 : public OpConversionPattern<bufferization::ToTensorOp> { 62 public: 63 using OpConversionPattern::OpConversionPattern; 64 LogicalResult 65 matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor, 66 ConversionPatternRewriter &rewriter) const override { 67 rewriter.replaceOp(op, adaptor.memref()); 68 return success(); 69 } 70 }; 71 } // namespace 72 73 namespace { 74 // In a finalizing bufferize conversion, we know that all tensors have been 75 // converted to memrefs, thus, this op becomes an identity. 76 class BufferizeToMemrefOp 77 : public OpConversionPattern<bufferization::ToMemrefOp> { 78 public: 79 using OpConversionPattern::OpConversionPattern; 80 LogicalResult 81 matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor, 82 ConversionPatternRewriter &rewriter) const override { 83 rewriter.replaceOp(op, adaptor.tensor()); 84 return success(); 85 } 86 }; 87 } // namespace 88 89 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns( 90 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 91 patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter, 92 patterns.getContext()); 93 } 94 95 namespace { 96 struct FinalizingBufferizePass 97 : public FinalizingBufferizeBase<FinalizingBufferizePass> { 98 using FinalizingBufferizeBase< 99 FinalizingBufferizePass>::FinalizingBufferizeBase; 100 101 void runOnFunction() override { 102 auto func = getFunction(); 103 auto *context = &getContext(); 104 105 BufferizeTypeConverter typeConverter; 106 RewritePatternSet patterns(context); 107 ConversionTarget target(*context); 108 109 populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns); 110 111 // If all result types are legal, and all block arguments are legal (ensured 112 // by func conversion above), then all types in the program are legal. 113 // 114 // We also check that the operand types are legal to avoid creating invalid 115 // IR. For example, this prevents 116 // populateEliminateBufferizeMaterializationsPatterns from updating the 117 // types of the operands to a return op without updating the enclosing 118 // function. 119 target.markUnknownOpDynamicallyLegal( 120 [&](Operation *op) { return typeConverter.isLegal(op); }); 121 122 if (failed(applyFullConversion(func, target, std::move(patterns)))) 123 signalPassFailure(); 124 } 125 }; 126 } // namespace 127 128 std::unique_ptr<FunctionPass> 129 mlir::bufferization::createFinalizingBufferizePass() { 130 return std::make_unique<FinalizingBufferizePass>(); 131 } 132