1 //===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
10 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
11
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Support/LLVM.h"
15 #include "llvm/ADT/SetVector.h"
16
17 namespace mlir {
18 class OpBuilder;
19
20 namespace bufferization {
21
22 class AnalysisState;
23 class BufferizableOpInterface;
24 struct DialectAnalysisState;
25
26 class OpFilter {
27 public:
28 /// An op filter entry. Filters can be used to specify which ops should be
29 /// processed by the bufferization.
30 struct Entry {
31 /// If the filter function evaluates to `true`, the filter matches.
32 using FilterFn = std::function<bool(Operation *)>;
33
34 /// Filter type: A filter can either be a DENY filter or an ALLOW filter.
35 enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
36
37 FilterFn fn;
38 FilterType type;
39 };
40
41 /// Return whether the op is allowed or not.
42 ///
43 /// If the filter does not have an ALLOW rule, ops are allowed by default,
44 /// unless they are explicitly marked as DENY. If the filter has at least one
45 /// ALLOW rule, ops are denied by default and only allowed if they match
46 /// an ALLOW rule and no DENY rule.
47 bool isOpAllowed(Operation *op) const;
48
49 /// Allow the given dialects.
50 ///
51 /// This function adds one or multiple ALLOW entries.
52 template <typename... DialectTs>
allowDialect()53 void allowDialect() {
54 // The following expands a call to allowDialectImpl for each dialect
55 // in 'DialectTs'. This magic is necessary due to a limitation in the places
56 // that a parameter pack can be expanded in c++11.
57 // FIXME: In c++17 this can be simplified by using 'fold expressions'.
58 (void)std::initializer_list<int>{0, (allowDialectImpl<DialectTs>(), 0)...};
59 }
60
61 /// Deny the given dialects.
62 ///
63 /// This function adds one or multiple DENY entries.
64 template <typename... DialectTs>
denyDialect()65 void denyDialect() {
66 // FIXME: In c++17 this can be simplified by using 'fold expressions'.
67 (void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
68 }
69
70 /// Allow the given dialect.
71 ///
72 /// This function adds an ALLOW entry.
allowDialect(StringRef dialectNamespace)73 void allowDialect(StringRef dialectNamespace) {
74 Entry::FilterFn filterFn = [=](Operation *op) {
75 return op->getDialect()->getNamespace() == dialectNamespace;
76 };
77 entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
78 }
79
80 /// Allow the given ops.
81 ///
82 /// This function adds one or multiple ALLOW entries.
83 template <typename... OpTys>
allowOperation()84 void allowOperation() {
85 // FIXME: In c++17 this can be simplified by using 'fold expressions'.
86 (void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
87 }
88
89 /// Deny the given ops.
90 ///
91 /// This function adds one or multiple DENY entries.
92 template <typename... OpTys>
denyOperation()93 void denyOperation() {
94 // FIXME: In c++17 this can be simplified by using 'fold expressions'.
95 (void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
96 }
97
98 /// Allow the given op.
99 ///
100 /// This function adds an ALLOW entry.
allowOperation(StringRef opName)101 void allowOperation(StringRef opName) {
102 Entry::FilterFn filterFn = [=](Operation *op) {
103 return op->getName().getStringRef() == opName;
104 };
105 allowOperation(filterFn);
106 }
107
108 /// Deny the given op.
109 ///
110 /// This function adds a DENY entry.
denyOperation(StringRef opName)111 void denyOperation(StringRef opName) {
112 Entry::FilterFn filterFn = [=](Operation *op) {
113 return op->getName().getStringRef() == opName;
114 };
115 denyOperation(filterFn);
116 }
117
118 /// Allow ops that are matched by `fn`.
119 ///
120 /// This function adds an ALLOW entry.
allowOperation(Entry::FilterFn fn)121 void allowOperation(Entry::FilterFn fn) {
122 entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
123 }
124
125 /// Deny ops that are matched by `fn`.
126 ///
127 /// This function adds a DENY entry.
denyOperation(Entry::FilterFn fn)128 void denyOperation(Entry::FilterFn fn) {
129 entries.push_back(Entry{fn, Entry::FilterType::DENY});
130 }
131
132 private:
133 /// Return `true` if the filter has at least one ALLOW rule.
hasAllowRule()134 bool hasAllowRule() const {
135 for (const Entry &e : entries)
136 if (e.type == Entry::FilterType::ALLOW)
137 return true;
138 return false;
139 }
140
141 /// Allow a dialect.
142 template <typename DialectT>
allowDialectImpl()143 void allowDialectImpl() {
144 allowDialect(DialectT::getDialectNamespace());
145 }
146
147 /// Deny a dialect.
148 template <typename DialectT>
denyDialectImpl()149 void denyDialectImpl() {
150 denyDialect(DialectT::getDialectNamespace());
151 }
152
153 /// Allow an op.
154 template <typename OpTy>
allowOperationImpl()155 void allowOperationImpl() {
156 allowOperation(OpTy::getOperationName());
157 }
158
159 /// Deny an op.
160 template <typename OpTy>
denyOperationImpl()161 void denyOperationImpl() {
162 denyOperation(OpTy::getOperationName());
163 }
164
165 /// A list of filter entries that determine whether an op should be allowed or
166 /// denied. If the filter has an ALLOW rule, only ops that are allowed and not
167 /// denied are allowed. If the filter does not have an ALLOW rule, only ops
168 /// that are not denied are allowed.
169 SmallVector<Entry> entries;
170 };
171
172 /// Options for BufferizableOpInterface-based bufferization.
173 struct BufferizationOptions {
174 /// Allocator function: Generate a memref allocation with the given type,
175 /// dynamic extents and alignment.
176 using AllocationFn = std::function<FailureOr<Value>(
177 OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
178 /// Deallocator function: Deallocate a buffer that was allocated with
179 /// AllocatorFn.
180 using DeallocationFn =
181 std::function<LogicalResult(OpBuilder &, Location, Value)>;
182 /// Memcpy function: Generate a memcpy between two buffers.
183 using MemCpyFn =
184 std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
185 /// Initializer function for analysis state.
186 using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
187 /// Initializer function for dialect-specific analysis state.
188 using DialectStateInitFn =
189 std::function<std::unique_ptr<DialectAnalysisState>()>;
190 /// Tensor -> MemRef type converter.
191 /// Parameters: Value, memory space, bufferization options
192 using UnknownTypeConverterFn = std::function<BaseMemRefType(
193 Value, unsigned, const BufferizationOptions &)>;
194
195 enum class LayoutMapOption : int8_t {
196 InferLayoutMap = 0,
197 IdentityLayoutMap = 1,
198 FullyDynamicLayoutMap = 2
199 };
200
201 BufferizationOptions();
202
203 /// Try to cast the given op to BufferizableOpInterface if the op is allow
204 /// listed.
205 BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
206
207 /// Try to cast the given value to BufferizableOpInterface if the op is allow
208 /// listed.
209 BufferizableOpInterface dynCastBufferizableOp(Value value) const;
210
211 /// A filter that specifies which ops should be bufferized and which ops
212 /// should be ignored.
213 OpFilter opFilter;
214
215 /// Return `true` if the given op should be bufferized.
216 bool isOpAllowed(Operation *op) const;
217
218 /// Helper functions for allocation, deallocation, memory copying.
219 Optional<AllocationFn> allocationFn;
220 Optional<DeallocationFn> deallocationFn;
221 Optional<MemCpyFn> memCpyFn;
222
223 /// Create a memref allocation with the given type and dynamic extents.
224 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
225 ValueRange dynShape) const;
226
227 /// Creates a memref deallocation. The given memref buffer must have been
228 /// allocated using `createAlloc`.
229 LogicalResult createDealloc(OpBuilder &b, Location loc,
230 Value allocatedBuffer) const;
231
232 /// Creates a memcpy between two given buffers.
233 LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from,
234 Value to) const;
235
236 /// Specifies whether not bufferizable ops are allowed in the input. If so,
237 /// bufferization.to_memref and bufferization.to_tensor ops are inserted at
238 /// the boundaries.
239 bool allowUnknownOps = false;
240
241 /// Specifies whether function boundaries (ops in the func dialect) should be
242 /// bufferized or not.
243 bool bufferizeFunctionBoundaries = false;
244
245 /// The default memory space that should be used when it cannot be inferred
246 /// from the context. If no default memory space is specified, bufferization
247 /// fails when the memory space cannot be inferred at any point.
248 Optional<unsigned> defaultMemorySpace = 0;
249
250 /// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
251 /// If this flag is set to `false`, those invariants are no longer enforced
252 /// with buffer copies.
253 ///
254 /// Note: Deactivating this flag can lead to incorrect bufferization results
255 /// when used incorrectly. This flag is useful with
256 /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor
257 /// OpOperands out-of-place.
258 bool enforceAliasingInvariants = true;
259
260 /// This flag controls buffer types on function signatures.
261 ///
262 /// * InferLayoutMap: All function parameter types have a fully dynamic layout
263 /// map, but function result types are inferred from the body of the
264 /// function.
265 /// * FullyDynamicLayoutMap: All function parameter types and result types
266 /// have a fully dynamic layout map. This option is most efficient because
267 /// any layout map can be casted to a fully dynamic one.
268 /// * IdentityLayoutMap: All function parameter types and result types have a
269 /// static identity layout (i.e., no layout map). This option may introduce
270 /// additional buffer allocs and copies because layout maps cannot be casted
271 /// away.
272 ///
273 /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
274 ///
275 /// Note: Inferred layout maps may not be desireable when interacting with
276 /// external functions, because the generated function signatures will be less
277 /// predictable.
278 LayoutMapOption functionBoundaryTypeConversion =
279 LayoutMapOption::InferLayoutMap;
280
281 /// Type converter from tensors to memrefs. This type converter is used if no
282 /// memref type could be inferred during bufferization. By default, a type
283 /// converter that returns a memref type with a fully dynamic layout map is
284 /// used.
285 UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
286
287 /// Specifies whether dealloc ops should be generated along with alloc ops. If
288 /// not, new memory allocations will leak.
289 bool createDeallocs = true;
290
291 /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
292 /// Should be used only with `testAnalysisOnly = true`.
293 unsigned analysisFuzzerSeed = 0;
294
295 /// If set to `true`, does not modify the IR apart from adding attributes (for
296 /// checking the results of the analysis) and post analysis steps.
297 bool testAnalysisOnly = false;
298
299 /// If set to `true`, the IR is annotated with details about RaW conflicts.
300 /// For debugging only. Should be used together with `testAnalysisOnly`.
301 bool printConflicts = false;
302
303 /// Buffer alignment for new memory allocations.
304 unsigned int bufferAlignment = 128;
305
306 /// Initializer functions for analysis state. These can be used to
307 /// initialize dialect-specific analysis state.
308 SmallVector<AnalysisStateInitFn> stateInitializers;
309
310 /// Add a analysis state initializer that initializes the specified
311 /// dialect-specific analysis state.
312 void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
313 };
314
/// Describes how two buffers relate to each other. A more precise relation
/// gives the analysis more information to work with.
enum class BufferRelation {
  /// No specific relationship between the buffers.
  None = 0,
  // TODO: ResultContainsOperand,
  // TODO: OperandContainsResult,
  /// The buffers are equivalent.
  Equivalent = 1
};
322
323 /// Return `true` if the given value is a BlockArgument of a func::FuncOp.
324 bool isFunctionArgument(Value value);
325
/// Dialect-specific analysis state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
/// variants of this struct.
struct DialectAnalysisState {
  DialectAnalysisState() = default;

  virtual ~DialectAnalysisState() = default;

  // Copying state is forbidden. Always pass as reference. Both copy
  // construction and copy assignment are deleted; deleting only the copy
  // constructor would still leave an implicitly generated copy assignment
  // operator.
  DialectAnalysisState(const DialectAnalysisState &) = delete;
  DialectAnalysisState &operator=(const DialectAnalysisState &) = delete;
};
337
338 /// AnalysisState provides a variety of helper functions for dealing with
339 /// tensor values.
340 class AnalysisState {
341 public:
342 /// Determine which OpOperand* will alias with `result` if the op is
343 /// bufferized in place. Return an empty vector if the op is not bufferizable.
344 SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
345
346 /// Determine which OpResult will alias with `opOperand` if the op is
347 /// bufferized in place. Return an empty vector if the op is not bufferizable.
348 SmallVector<OpResult> getAliasingOpResult(OpOperand &opOperand) const;
349
350 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
351 /// the op is not bufferizable.
352 bool bufferizesToMemoryRead(OpOperand &opOperand) const;
353
354 /// Return true if `opOperand` bufferizes to a memory write. Return true` if
355 /// the op is not bufferizable.
356 bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
357
358 /// Return true if `opOperand` does neither read nor write but bufferizes to
359 /// an alias. Return false if the op is not bufferizable.
360 bool bufferizesToAliasOnly(OpOperand &opOperand) const;
361
362 /// Return true if a copy can always be avoided when allocating a new tensor
363 /// for the given OpOperand.
364 bool canOmitTensorCopy(OpOperand &opOperand) const;
365
366 /// Return true if the given value is read by an op that bufferizes to a
367 /// memory read. Also takes into account ops that create an alias but do not
368 /// read by themselves (e.g., ExtractSliceOp).
369 bool isValueRead(Value value) const;
370
371 /// Starting from `value`, follow the use-def chain in reverse, always
372 /// selecting the aliasing OpOperands. Find and return Values for which
373 /// `condition` evaluates to true. OpOperands of such matching Values are not
374 /// traversed any further.
375 ///
376 /// When reaching the end of a chain (BlockArgument or Value without aliasing
377 /// OpOperands), also return the last Value of that chain.
378 ///
379 /// Example:
380 ///
381 /// 8
382 /// |
383 /// 6* 7* +-----+----+
384 /// | | | |
385 /// 2* 3 4* 5
386 /// | | | |
387 /// +----------+----------+----------+
388 /// |
389 /// 1
390 ///
391 /// In the above example, Values with a star satisfy the condition. When
392 /// starting the traversal from Value 1, the resulting SetVector is:
393 /// { 2, 7, 8, 5 }
394 SetVector<Value> findValueInReverseUseDefChain(
395 Value value, llvm::function_ref<bool(Value)> condition) const;
396
397 /// Find the Values of the last preceding write of a given Value.
398 ///
399 /// Note: Unknown ops are handled conservatively and assumed to be writes.
400 /// Furthermore, BlockArguments are also assumed to be writes. There is no
401 /// analysis across block boundaries.
402 ///
403 /// Note: When reaching an end of the reverse SSA use-def chain, that value
404 /// is returned regardless of whether it is a memory write or not.
405 SetVector<Value> findLastPrecedingWrite(Value value) const;
406
407 /// Return `true` if the given OpResult has been decided to bufferize inplace.
408 virtual bool isInPlace(OpOperand &opOperand) const;
409
410 /// Return true if `v1` and `v2` bufferize to equivalent buffers.
411 virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const;
412
413 /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
414 virtual bool areAliasingBufferizedValues(Value v1, Value v2) const;
415
416 /// Return `true` if the given tensor has undefined contents.
417 virtual bool hasUndefinedContents(OpOperand *opOperand) const;
418
419 /// Return true if the given tensor (or an aliasing tensor) is yielded from
420 /// the containing block. Also include all aliasing tensors in the same block.
421 ///
422 /// Note: In the absence of an analysis, an implementation may return true for
423 /// any given tensor.
424 virtual bool isTensorYielded(Value tensor) const;
425
426 /// Return `true` if the given dialect state exists.
hasDialectState(StringRef name)427 bool hasDialectState(StringRef name) const {
428 auto it = dialectState.find(name);
429 return it != dialectState.end();
430 }
431
432 /// Return dialect-specific bufferization state.
433 template <typename StateT>
getDialectState(StringRef name)434 Optional<const StateT *> getDialectState(StringRef name) const {
435 auto it = dialectState.find(name);
436 if (it == dialectState.end())
437 return None;
438 return static_cast<const StateT *>(it->getSecond().get());
439 }
440
441 /// Return dialect-specific analysis state or create one if none exists.
442 template <typename StateT>
getOrCreateDialectState(StringRef name)443 StateT &getOrCreateDialectState(StringRef name) {
444 // Create state if it does not exist yet.
445 if (!hasDialectState(name))
446 dialectState[name] = std::make_unique<StateT>();
447 return static_cast<StateT &>(*dialectState[name]);
448 }
449
insertDialectState(StringRef name,std::unique_ptr<DialectAnalysisState> state)450 void insertDialectState(StringRef name,
451 std::unique_ptr<DialectAnalysisState> state) {
452 assert(!dialectState.count(name) && "dialect state already initialized");
453 dialectState[name] = std::move(state);
454 }
455
456 /// Return a reference to the BufferizationOptions.
getOptions()457 const BufferizationOptions &getOptions() const { return options; }
458
459 explicit AnalysisState(const BufferizationOptions &options);
460
461 // AnalysisState should be passed as a reference.
462 AnalysisState(const AnalysisState &) = delete;
463
464 virtual ~AnalysisState() = default;
465
466 private:
467 /// Dialect-specific analysis state.
468 DenseMap<StringRef, std::unique_ptr<DialectAnalysisState>> dialectState;
469
470 /// A reference to current bufferization options.
471 const BufferizationOptions &options;
472 };
473
474 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
475 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
476 /// undefined contents is allocated.
477 FailureOr<Value>
478 allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
479 bool escape, const BufferizationOptions &options,
480 bool copy = true);
481
482 /// Return `true` if the allocation of the given op is guaranteed to not escape
483 /// the containing block.
484 bool allocationDoesNotEscape(OpResult opResult);
485
486 /// Lookup the buffer for the given value. If the value was not bufferized
487 /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
488 /// from which the memref operand is returned.
489 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
490 const BufferizationOptions &options);
491
492 /// Return the buffer type for a given Value (tensor) after bufferization.
493 ///
494 /// Note: Op implementations should preferrably call `getBuffer()->getType()`.
495 /// This function should only be used if `getBuffer` cannot be used.
496 FailureOr<BaseMemRefType> getBufferType(Value value,
497 const BufferizationOptions &options);
498
499 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
500 /// must be replaced with memref values.
501 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
502 ValueRange values);
503
504 /// Replace an op with a new op. The new op must have the same number of
505 /// results as the replaced op. The new op may not return any tensor values.
506 template <typename OpTy, typename... Args>
replaceOpWithNewBufferizedOp(RewriterBase & rewriter,Operation * op,Args &&...args)507 OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
508 Args &&...args) {
509 auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
510 replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
511 return newOp;
512 }
513
514 /// Return `true` if the buffer of given OpResult should be deallocated. This
515 /// function should be called during `BufferizableOpInterface::bufferize`
516 /// implementations that allocate a new buffer for the given OpResult.
517 bool shouldDeallocateOpResult(OpResult opResult,
518 const BufferizationOptions &options);
519
520 /// Return a MemRefType to which the type of the given value can be bufferized.
521 ///
522 /// If possible, op bufferization implementations should not use this function
523 /// and instead infer precise memref types for tensor results by themselves.
524 ///
525 /// Unless a layout map was specified, `options.unknownTypeConverterFn`
526 /// determines what kind of layout map will be used. For best composability
527 /// (without copies), the fully dynamic layout map is used by default.
528 ///
529 /// Note: Canonicalization patterns could clean up layout maps and infer more
530 /// precise layout maps after bufferization. However, many possible
531 /// canonicalizations are currently not implemented.
532 BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
533 MemRefLayoutAttrInterface layout = {},
534 unsigned memorySpace = 0);
535
536 /// Return a MemRef type with fully dynamic layout. If the given tensor type
537 /// is unranked, return an unranked MemRef type.
538 BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
539 unsigned memorySpace = 0);
540
541 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
542 /// the given tensor type is unranked, return an unranked MemRef type.
543 BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
544 unsigned memorySpace = 0);
545
546 } // namespace bufferization
547 } // namespace mlir
548
549 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
550
551 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
552