//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
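// For example (illustrative IR), in
//   %t = bufferization.to_tensor %m : memref<?xf32>
// all uses of %t can simply be replaced with %m once the surrounding ops
// operate on memrefs.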
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.memref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.tensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by func conversion above), then all types in the program are
    // legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const AnalysisBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<bufferization::BufferizationDialect>();
  }

  void runOnOperation() override {
    AnalysisBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating
      // the pass.
      opt.allowReturnMemref = allowReturnMemref;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;

      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
          [&](Operation *op) {
            // Disallow ops from the std dialect, i.e., ops related to
            // function calls.
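            // (For example, call and return ops live in the std dialect at
            // this point and are therefore filtered out.)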
            if (op->getDialect()->getNamespace() ==
                StandardOpsDialect::getDialectNamespace())
              return false;
            // Filter may be specified via options.
            if (this->dialectFilter.hasValue())
              return llvm::find(this->dialectFilter,
                                op->getDialect()->getNamespace()) !=
                     this->dialectFilter.end();
            // No filter specified: All other ops are allowed.
            return true;
          };
      opt.allowOperationInFilter(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (failed(runOneShotBufferize(moduleOp, opt))) {
      signalPassFailure();
      return;
    }

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<AnalysisBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const AnalysisBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

/// Rewrite pattern that bufferizes bufferizable ops.
struct BufferizationPattern
    : public OpInterfaceRewritePattern<BufferizableOpInterface> {
  BufferizationPattern(MLIRContext *context, const BufferizationState &state,
                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
        state(state) {}

  LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
                                PatternRewriter &rewriter) const override {
    // No tensors => no buffers.
    if (!hasTensorSemantics(bufferizableOp.getOperation()))
      return failure();
    if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
      return failure();
    return bufferizableOp.bufferize(rewriter, state);
  }

private:
  const BufferizationState &state;
};

/// Check the result of bufferization. Return an error if an op was not
/// bufferized, unless partial bufferization is allowed.
static LogicalResult
checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
  if (!options.allowUnknownOps) {
    // Check if all ops were bufferized.
    LogicalResult status = success();
    op->walk([&](Operation *op) {
      if (!hasTensorSemantics(op))
        return WalkResult::advance();

      // Bufferization dialect ops will canonicalize away if all other ops are
      // bufferized.
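      // (E.g., a to_memref of a matching to_tensor folds back to the
      // underlying memref value.)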
      if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
        return WalkResult::advance();

      // Ops that are not in the allow list can be ignored.
      if (!options.isOpAllowed(op))
        return WalkResult::advance();

      // Ops without any uses and no side effects will fold away.
      if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
        return WalkResult::advance();

      status = op->emitError("op was not bufferized");
      return WalkResult::interrupt();
    });

    if (failed(status))
      return status;
  }

  return success();
}

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationState &state) {
  // Bufferize the op and its nested ops.
  RewritePatternSet patterns(op->getContext());
  populateBufferizationPattern(state, patterns);
  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
    return failure();

  return checkBufferizationResult(op, state.getOptions());
}

namespace {
/// This is a "no analysis, always copy" BufferizationState. In the absence of
/// an analysis, a buffer must be copied each time it is written to. Therefore,
/// all OpOperands that bufferize to a memory write must bufferize
/// out-of-place.
class AlwaysCopyBufferizationState : public BufferizationState {
public:
  AlwaysCopyBufferizationState(const BufferizationOptions &options)
      : BufferizationState(options) {}

  AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;

  virtual ~AlwaysCopyBufferizationState() = default;

  /// Return `true` if the given OpOperand has been decided to bufferize
  /// in-place.
  bool isInPlace(OpOperand &opOperand) const override {
    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
    // alloc and copy is inserted.
    return !bufferizesToMemoryWrite(opOperand);
  }

  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
    // There is no analysis, so we do not know if the values are equivalent.
    // The conservative answer is "false".
    return false;
  }
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options) {
  AlwaysCopyBufferizationState state(options);
  return bufferizeOp(op, state);
}

void bufferization::populateBufferizationPattern(
    const BufferizationState &state, RewritePatternSet &patterns) {
  patterns.add<BufferizationPattern>(patterns.getContext(), state);
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowReturnMemref = true;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.fullyDynamicLayoutMaps = false;
  return options;
}