1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17 18 using namespace mlir; 19 using namespace mlir::bufferization; 20 21 //===----------------------------------------------------------------------===// 22 // BufferizeTypeConverter 23 //===----------------------------------------------------------------------===// 24 25 static Value materializeToTensor(OpBuilder &builder, TensorType type, 26 ValueRange inputs, Location loc) { 27 assert(inputs.size() == 1); 28 assert(inputs[0].getType().isa<BaseMemRefType>()); 29 return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]); 30 } 31 32 /// Registers conversions into BufferizeTypeConverter 33 BufferizeTypeConverter::BufferizeTypeConverter() { 34 // Keep all types unchanged. 35 addConversion([](Type type) { return type; }); 36 // Convert RankedTensorType to MemRefType. 37 addConversion([](RankedTensorType type) -> Type { 38 return MemRefType::get(type.getShape(), type.getElementType()); 39 }); 40 // Convert UnrankedTensorType to UnrankedMemRefType. 41 addConversion([](UnrankedTensorType type) -> Type { 42 return UnrankedMemRefType::get(type.getElementType(), 0); 43 }); 44 addArgumentMaterialization(materializeToTensor); 45 addSourceMaterialization(materializeToTensor); 46 addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, 47 ValueRange inputs, Location loc) -> Value { 48 assert(inputs.size() == 1); 49 assert(inputs[0].getType().isa<TensorType>()); 50 return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]); 51 }); 52 } 53 54 void mlir::bufferization::populateBufferizeMaterializationLegality( 55 ConversionTarget &target) { 56 target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>(); 57 } 58 59 namespace { 60 // In a finalizing bufferize conversion, we know that all tensors have been 61 // converted to memrefs, thus, this op becomes an identity. 62 class BufferizeToTensorOp 63 : public OpConversionPattern<bufferization::ToTensorOp> { 64 public: 65 using OpConversionPattern::OpConversionPattern; 66 LogicalResult 67 matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor, 68 ConversionPatternRewriter &rewriter) const override { 69 rewriter.replaceOp(op, adaptor.memref()); 70 return success(); 71 } 72 }; 73 } // namespace 74 75 namespace { 76 // In a finalizing bufferize conversion, we know that all tensors have been 77 // converted to memrefs, thus, this op becomes an identity. 78 class BufferizeToMemrefOp 79 : public OpConversionPattern<bufferization::ToMemrefOp> { 80 public: 81 using OpConversionPattern::OpConversionPattern; 82 LogicalResult 83 matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor, 84 ConversionPatternRewriter &rewriter) const override { 85 rewriter.replaceOp(op, adaptor.tensor()); 86 return success(); 87 } 88 }; 89 } // namespace 90 91 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns( 92 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 93 patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter, 94 patterns.getContext()); 95 } 96 97 namespace { 98 struct FinalizingBufferizePass 99 : public FinalizingBufferizeBase<FinalizingBufferizePass> { 100 using FinalizingBufferizeBase< 101 FinalizingBufferizePass>::FinalizingBufferizeBase; 102 103 void runOnOperation() override { 104 auto func = getOperation(); 105 auto *context = &getContext(); 106 107 BufferizeTypeConverter typeConverter; 108 RewritePatternSet patterns(context); 109 ConversionTarget target(*context); 110 111 populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns); 112 113 // If all result types are legal, and all block arguments are legal (ensured 114 // by func conversion above), then all types in the program are legal. 115 // 116 // We also check that the operand types are legal to avoid creating invalid 117 // IR. For example, this prevents 118 // populateEliminateBufferizeMaterializationsPatterns from updating the 119 // types of the operands to a return op without updating the enclosing 120 // function. 121 target.markUnknownOpDynamicallyLegal( 122 [&](Operation *op) { return typeConverter.isLegal(op); }); 123 124 if (failed(applyFullConversion(func, target, std::move(patterns)))) 125 signalPassFailure(); 126 } 127 }; 128 } // namespace 129 130 std::unique_ptr<OperationPass<FuncOp>> 131 mlir::bufferization::createFinalizingBufferizePass() { 132 return std::make_unique<FinalizingBufferizePass>(); 133 } 134 135 static bool isaTensor(Type t) { return t.isa<TensorType>(); } 136 137 /// Return true if the given op has a tensor result or a tensor operand. 138 static bool hasTensorSemantics(Operation *op) { 139 bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); 140 bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); 141 return hasTensorResult || hasTensorOperand; 142 } 143 144 /// Rewrite pattern that bufferizes bufferizable ops. 145 struct BufferizationPattern 146 : public OpInterfaceRewritePattern<BufferizableOpInterface> { 147 BufferizationPattern(MLIRContext *context, const BufferizationState &state, 148 PatternBenefit benefit = 1) 149 : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit), 150 state(state) {} 151 152 LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp, 153 PatternRewriter &rewriter) const override { 154 // No tensors => no buffers. 155 if (!hasTensorSemantics(bufferizableOp.getOperation())) 156 return failure(); 157 if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation())) 158 return failure(); 159 return bufferizableOp.bufferize(rewriter, state); 160 } 161 162 private: 163 const BufferizationState &state; 164 }; 165 166 /// Check the result of bufferization. Return an error if an op was not 167 /// bufferized, unless partial bufferization is allowed. 168 static LogicalResult 169 checkBufferizationResult(Operation *op, const BufferizationOptions &options) { 170 if (!options.allowUnknownOps) { 171 // Check if all ops were bufferized. 172 LogicalResult status = success(); 173 op->walk([&](Operation *op) { 174 if (!hasTensorSemantics(op)) 175 return WalkResult::advance(); 176 177 // Bufferization dialect ops will canonicalize away if all other ops are 178 // bufferized. 179 if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op)) 180 return WalkResult::advance(); 181 182 // Ops that are not in the allow list can be ignored. 183 if (!options.isOpAllowed(op)) 184 return WalkResult::advance(); 185 186 // Ops without any uses and no side effects will fold away. 187 if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) 188 return WalkResult::advance(); 189 190 status = op->emitError("op was not bufferized"); 191 return WalkResult::interrupt(); 192 }); 193 194 if (failed(status)) 195 return status; 196 } 197 198 return success(); 199 } 200 201 LogicalResult bufferization::bufferizeOp(Operation *op, 202 const BufferizationState &state) { 203 // Bufferize the op and its nested ops. 204 OwningRewritePatternList patterns(op->getContext()); 205 patterns.add<BufferizationPattern>(op->getContext(), state); 206 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) 207 return failure(); 208 209 return checkBufferizationResult(op, state.getOptions()); 210 } 211