//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

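/// Convert a memref value to a tensor value of the given type by inserting a
/// bufferization.to_tensor op. Used as the argument/source materialization of
/// the type converter below.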
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1 && "expected exactly one input");
  assert(inputs[0].getType().isa<BaseMemRefType>() && "expected memref type");
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = type.dyn_cast<MemRefType>();
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (inputs[0].getType().isa<TensorType>()) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getMemref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs; thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal (as
    // ensured by a prior func conversion), then all types in the program are
    // legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

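/// Translate the layout-map option string of the pass into the corresponding
/// enum value.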
static BufferizationOptions::LayoutMapOption
parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating
      // the pass.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.functionBoundaryTypeConversion =
          parseLayoutMapOption(functionBoundaryTypeConversion);
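      // Leave the default memory space unset so that memory spaces must be
      // inferred from the IR.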
      if (mustInferMemorySpace)
        opt.defaultMemorySpace = None;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      // Configure type converter.
      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = value.getType().cast<TensorType>();
        if (unknownTypeConversionOption ==
            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(
            unknownTypeConversionOption ==
                BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
            "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: all ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

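    // Clean up the bufferized IR. The result of the cleanup pipeline is
    // deliberately discarded; a cleanup failure does not fail the pass.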
    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace

namespace {
struct BufferizationBufferizePass
    : public BufferizationBufferizeBase<BufferizationBufferizePass> {
  void runOnOperation() override {
    BufferizationOptions options = getPartialBufferizationOptions();
    options.opFilter.allowDialect<BufferizationDialect>();

    if (failed(bufferizeOp(getOperation(), options)))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
  return std::make_unique<BufferizationBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// Return true if the given op has a tensor result or a tensor operand. For
/// function-like ops, the argument and result types are inspected instead.
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        const OpFilter *opFilter)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), opFilter(opFilter) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);
    erasedOps.erase(op);

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    const auto &options = analysisState.getOptions();
    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
      return;

#ifndef NDEBUG
    // Read-only tensor ops may be created during bufferization. Ops that are
    // writing should not be created because such ops were never analyzed.
    // Bufferizing such ops could introduce a RaW conflict.
    for (OpOperand &operand : op->getOpOperands())
      if (operand.get().getType().isa<TensorType>())
        assert(!analysisState.bufferizesToMemoryWrite(operand) &&
               "creating tensor ops that bufferize to a memory write is not "
               "allowed during bufferization");
#endif // NDEBUG

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// An extra op filter for bufferization.
  const OpFilter *opFilter;
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         bool copyBeforeWrite,
                                         const OpFilter *opFilter) {
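  // In copy-before-write mode, no analysis is run; tensor copies are inserted
  // conservatively before every write instead.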
  if (copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  //
  // FuncOps must be bufferized before their bodies, so add them to the
  // worklist first.
  SmallVector<Operation *> worklist;
  op->walk([&](func::FuncOp funcOp) {
    if (hasTensorSemantics(funcOp))
      worklist.push_back(funcOp);
  });
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, opFilter);
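  // Note: The worklist may grow while we iterate over it. Newly created ops
  // with tensor semantics are appended by the rewriter's
  // notifyOperationInserted hook, so iterate by index to visit them as well.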
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *op = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(op))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(op))
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(op);
    if (failed(bufferizableOp.bufferize(rewriter, options)))
      return op->emitError("failed to bufferize op");
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed to be bufferized.
    if (!options.isOpAllowed(op))
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

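/// Return options for a partial bufferization: unknown ops may remain
/// unbufferized (they are expected to be handled by later bufferization
/// passes) and no buffer deallocations are created.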
BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.enforceAliasingInvariants = false;
  options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
        value.getType().cast<TensorType>(), memorySpace);
  };
  options.opFilter.allowDialect<BufferizationDialect>();
  return options;
}