1 //===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
10 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
11 
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include <string>
16 
17 namespace mlir {
18 class OpBuilder;
19 
20 namespace bufferization {
21 
22 class AnalysisState;
23 class BufferizableOpInterface;
24 struct DialectAnalysisState;
25 
26 class OpFilter {
27 public:
28   /// An op filter entry. Filters can be used to specify which ops should be
29   /// processed by the bufferization.
30   struct Entry {
31     /// If the filter function evaluates to `true`, the filter matches.
32     using FilterFn = std::function<bool(Operation *)>;
33 
34     /// Filter type: A filter can either be a DENY filter or an ALLOW filter.
35     enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
36 
37     FilterFn fn;
38     FilterType type;
39   };
40 
41   /// Return whether the op is allowed or not.
42   ///
43   /// If the filter does not have an ALLOW rule, ops are allowed by default,
44   /// unless they are explicitly marked as DENY. If the filter has at least one
45   /// ALLOW rule, ops are denied by default and only allowed if they match
46   /// an ALLOW rule and no DENY rule.
47   bool isOpAllowed(Operation *op) const;
48 
49   /// Allow the given dialects.
50   ///
51   /// This function adds one or multiple ALLOW entries.
52   template <typename... DialectTs>
allowDialect()53   void allowDialect() {
54     // The following expands a call to allowDialectImpl for each dialect
55     // in 'DialectTs'. This magic is necessary due to a limitation in the places
56     // that a parameter pack can be expanded in c++11.
57     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
58     (void)std::initializer_list<int>{0, (allowDialectImpl<DialectTs>(), 0)...};
59   }
60 
61   /// Deny the given dialects.
62   ///
63   /// This function adds one or multiple DENY entries.
64   template <typename... DialectTs>
denyDialect()65   void denyDialect() {
66     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
67     (void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
68   }
69 
70   /// Allow the given dialect.
71   ///
72   /// This function adds an ALLOW entry.
allowDialect(StringRef dialectNamespace)73   void allowDialect(StringRef dialectNamespace) {
74     Entry::FilterFn filterFn = [=](Operation *op) {
75       return op->getDialect()->getNamespace() == dialectNamespace;
76     };
77     entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
78   }
79 
80   /// Allow the given ops.
81   ///
82   /// This function adds one or multiple ALLOW entries.
83   template <typename... OpTys>
allowOperation()84   void allowOperation() {
85     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
86     (void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
87   }
88 
89   /// Deny the given ops.
90   ///
91   /// This function adds one or multiple DENY entries.
92   template <typename... OpTys>
denyOperation()93   void denyOperation() {
94     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
95     (void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
96   }
97 
98   /// Allow the given op.
99   ///
100   /// This function adds an ALLOW entry.
allowOperation(StringRef opName)101   void allowOperation(StringRef opName) {
102     Entry::FilterFn filterFn = [=](Operation *op) {
103       return op->getName().getStringRef() == opName;
104     };
105     allowOperation(filterFn);
106   }
107 
108   /// Deny the given op.
109   ///
110   /// This function adds a DENY entry.
denyOperation(StringRef opName)111   void denyOperation(StringRef opName) {
112     Entry::FilterFn filterFn = [=](Operation *op) {
113       return op->getName().getStringRef() == opName;
114     };
115     denyOperation(filterFn);
116   }
117 
118   /// Allow ops that are matched by `fn`.
119   ///
120   /// This function adds an ALLOW entry.
allowOperation(Entry::FilterFn fn)121   void allowOperation(Entry::FilterFn fn) {
122     entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
123   }
124 
125   /// Deny ops that are matched by `fn`.
126   ///
127   /// This function adds a DENY entry.
denyOperation(Entry::FilterFn fn)128   void denyOperation(Entry::FilterFn fn) {
129     entries.push_back(Entry{fn, Entry::FilterType::DENY});
130   }
131 
132 private:
133   /// Return `true` if the filter has at least one ALLOW rule.
hasAllowRule()134   bool hasAllowRule() const {
135     for (const Entry &e : entries)
136       if (e.type == Entry::FilterType::ALLOW)
137         return true;
138     return false;
139   }
140 
141   /// Allow a dialect.
142   template <typename DialectT>
allowDialectImpl()143   void allowDialectImpl() {
144     allowDialect(DialectT::getDialectNamespace());
145   }
146 
147   /// Deny a dialect.
148   template <typename DialectT>
denyDialectImpl()149   void denyDialectImpl() {
150     denyDialect(DialectT::getDialectNamespace());
151   }
152 
153   /// Allow an op.
154   template <typename OpTy>
allowOperationImpl()155   void allowOperationImpl() {
156     allowOperation(OpTy::getOperationName());
157   }
158 
159   /// Deny an op.
160   template <typename OpTy>
denyOperationImpl()161   void denyOperationImpl() {
162     denyOperation(OpTy::getOperationName());
163   }
164 
165   /// A list of filter entries that determine whether an op should be allowed or
166   /// denied. If the filter has an ALLOW rule, only ops that are allowed and not
167   /// denied are allowed. If the filter does not have an ALLOW rule, only ops
168   /// that are not denied are allowed.
169   SmallVector<Entry> entries;
170 };
171 
/// Options for BufferizableOpInterface-based bufferization.
///
/// Bundles the op filter, the optional allocation/deallocation/copy hooks and
/// the flags that steer the analysis and the tensor-to-memref type conversion.
struct BufferizationOptions {
  /// Allocator function: Generate a memref allocation with the given type,
  /// dynamic extents and alignment.
  using AllocationFn = std::function<FailureOr<Value>(
      OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
  /// Deallocator function: Deallocate a buffer that was allocated with
  /// AllocationFn.
  using DeallocationFn =
      std::function<LogicalResult(OpBuilder &, Location, Value)>;
  /// Memcpy function: Generate a memcpy between two buffers.
  /// NOTE(review): the two Value parameters presumably follow the (from, to)
  /// order of `createMemCpy` below -- confirm against the implementation.
  using MemCpyFn =
      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
  /// Initializer function for analysis state.
  using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
  /// Initializer function for dialect-specific analysis state. Returns the
  /// newly created state as an owning pointer.
  using DialectStateInitFn =
      std::function<std::unique_ptr<DialectAnalysisState>()>;
  /// Tensor -> MemRef type converter.
  /// Parameters: Value, memory space, bufferization options
  using UnknownTypeConverterFn = std::function<BaseMemRefType(
      Value, unsigned, const BufferizationOptions &)>;

  /// Layout map choices. The meaning of each value on function signatures is
  /// described at `functionBoundaryTypeConversion`.
  enum class LayoutMapOption : int8_t {
    InferLayoutMap = 0,
    IdentityLayoutMap = 1,
    FullyDynamicLayoutMap = 2
  };

  /// Construct the options. Defined out of line.
  BufferizationOptions();

  /// Try to cast the given op to BufferizableOpInterface if the op is
  /// allow-listed.
  BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;

  /// Try to cast the given value to BufferizableOpInterface if the op is
  /// allow-listed.
  BufferizableOpInterface dynCastBufferizableOp(Value value) const;

  /// A filter that specifies which ops should be bufferized and which ops
  /// should be ignored.
  OpFilter opFilter;

  /// Return `true` if the given op should be bufferized.
  bool isOpAllowed(Operation *op) const;

  /// Helper functions for allocation, deallocation, memory copying. Each hook
  /// is optional; what `createAlloc`/`createDealloc`/`createMemCpy` do when a
  /// hook is unset is decided by their out-of-line definitions.
  Optional<AllocationFn> allocationFn;
  Optional<DeallocationFn> deallocationFn;
  Optional<MemCpyFn> memCpyFn;

  /// Create a memref allocation with the given type and dynamic extents.
  FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
                               ValueRange dynShape) const;

  /// Creates a memref deallocation. The given memref buffer must have been
  /// allocated using `createAlloc`.
  LogicalResult createDealloc(OpBuilder &b, Location loc,
                              Value allocatedBuffer) const;

  /// Creates a memcpy between two given buffers.
  LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from,
                             Value to) const;

  /// Specifies whether non-bufferizable ops are allowed in the input. If so,
  /// bufferization.to_memref and bufferization.to_tensor ops are inserted at
  /// the boundaries.
  bool allowUnknownOps = false;

  /// Specifies whether function boundaries (ops in the func dialect) should be
  /// bufferized or not.
  bool bufferizeFunctionBoundaries = false;

  /// The default memory space that should be used when it cannot be inferred
  /// from the context. If no default memory space is specified, bufferization
  /// fails when the memory space cannot be inferred at any point.
  /// The in-class initializer selects memory space 0.
  Optional<unsigned> defaultMemorySpace = 0;

  /// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
  /// If this flag is set to `false`, those invariants are no longer enforced
  /// with buffer copies.
  ///
  /// Note: Deactivating this flag can lead to incorrect bufferization results
  /// when used incorrectly. This flag is useful with
  /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor
  /// OpOperands out-of-place.
  bool enforceAliasingInvariants = true;

  /// This flag controls buffer types on function signatures.
  ///
  /// * InferLayoutMap: All function parameter types have a fully dynamic layout
  ///   map, but function result types are inferred from the body of the
  ///   function.
  /// * FullyDynamicLayoutMap: All function parameter types and result types
  ///   have a fully dynamic layout map. This option is most efficient because
  ///   any layout map can be cast to a fully dynamic one.
  /// * IdentityLayoutMap: All function parameter types and result types have a
  ///   static identity layout (i.e., no layout map). This option may introduce
  ///   additional buffer allocs and copies because layout maps cannot be cast
  ///   away.
  ///
  /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
  ///
  /// Note: Inferred layout maps may not be desirable when interacting with
  /// external functions, because the generated function signatures will be less
  /// predictable.
  LayoutMapOption functionBoundaryTypeConversion =
      LayoutMapOption::InferLayoutMap;

  /// Type converter from tensors to memrefs. This type converter is used if no
  /// memref type could be inferred during bufferization. By default, a type
  /// converter that returns a memref type with a fully dynamic layout map is
  /// used.
  /// NOTE(review): the in-class initializer is `nullptr`, so the default
  /// described above is presumably installed by the constructor -- confirm.
  UnknownTypeConverterFn unknownTypeConverterFn = nullptr;

  /// Specifies whether dealloc ops should be generated along with alloc ops. If
  /// not, new memory allocations will leak.
  bool createDeallocs = true;

  /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
  /// Should be used only with `testAnalysisOnly = true`.
  unsigned analysisFuzzerSeed = 0;

  /// If set to `true`, does not modify the IR apart from adding attributes (for
  /// checking the results of the analysis) and post analysis steps.
  bool testAnalysisOnly = false;

  /// If set to `true`, the IR is annotated with details about RaW conflicts.
  /// For debugging only. Should be used together with `testAnalysisOnly`.
  bool printConflicts = false;

  /// Buffer alignment for new memory allocations.
  unsigned int bufferAlignment = 128;

  /// Initializer functions for analysis state. These can be used to
  /// initialize dialect-specific analysis state.
  SmallVector<AnalysisStateInitFn> stateInitializers;

  /// Add an analysis state initializer that initializes the specified
  /// dialect-specific analysis state.
  void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
};
314 
/// Specify fine-grain relationship between buffers to enable more analysis.
///
/// Only two relations exist so far; the TODO entries below are placeholders
/// for relations that are not implemented yet.
enum class BufferRelation {
  None,
  // TODO: ResultContainsOperand,
  // TODO: OperandContainsResult,
  Equivalent
};
322 
/// Return `true` if the given value is a BlockArgument of a func::FuncOp.
///
/// Declared here only; the exact rules are in the out-of-line definition,
/// which is not part of this header.
bool isFunctionArgument(Value value);
325 
/// Dialect-specific analysis state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
/// variants of this struct.
///
/// The struct is a polymorphic base: it has a virtual destructor so that
/// derived state can be destroyed through a base pointer.
struct DialectAnalysisState {
  DialectAnalysisState() = default;

  virtual ~DialectAnalysisState() = default;

  // Copying state is forbidden. Always pass as reference. Both copy
  // construction and copy assignment are deleted; deleting only the
  // constructor would still leave the implicit copy assignment available.
  DialectAnalysisState(const DialectAnalysisState &) = delete;
  DialectAnalysisState &operator=(const DialectAnalysisState &) = delete;
};
337 
338 /// AnalysisState provides a variety of helper functions for dealing with
339 /// tensor values.
340 class AnalysisState {
341 public:
342   /// Determine which OpOperand* will alias with `result` if the op is
343   /// bufferized in place. Return an empty vector if the op is not bufferizable.
344   SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
345 
346   /// Determine which OpResult will alias with `opOperand` if the op is
347   /// bufferized in place. Return an empty vector if the op is not bufferizable.
348   SmallVector<OpResult> getAliasingOpResult(OpOperand &opOperand) const;
349 
350   /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
351   /// the op is not bufferizable.
352   bool bufferizesToMemoryRead(OpOperand &opOperand) const;
353 
354   /// Return true if `opOperand` bufferizes to a memory write. Return true` if
355   /// the op is not bufferizable.
356   bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
357 
358   /// Return true if `opOperand` does neither read nor write but bufferizes to
359   /// an alias. Return false if the op is not bufferizable.
360   bool bufferizesToAliasOnly(OpOperand &opOperand) const;
361 
362   /// Return true if a copy can always be avoided when allocating a new tensor
363   /// for the given OpOperand.
364   bool canOmitTensorCopy(OpOperand &opOperand) const;
365 
366   /// Return true if the given value is read by an op that bufferizes to a
367   /// memory read. Also takes into account ops that create an alias but do not
368   /// read by themselves (e.g., ExtractSliceOp).
369   bool isValueRead(Value value) const;
370 
371   /// Starting from `value`, follow the use-def chain in reverse, always
372   /// selecting the aliasing OpOperands. Find and return Values for which
373   /// `condition` evaluates to true. OpOperands of such matching Values are not
374   /// traversed any further.
375   ///
376   /// When reaching the end of a chain (BlockArgument or Value without aliasing
377   /// OpOperands), also return the last Value of that chain.
378   ///
379   /// Example:
380   ///
381   ///                               8
382   ///                               |
383   ///   6*         7*         +-----+----+
384   ///   |          |          |          |
385   ///   2*         3          4*         5
386   ///   |          |          |          |
387   ///   +----------+----------+----------+
388   ///              |
389   ///              1
390   ///
391   /// In the above example, Values with a star satisfy the condition. When
392   /// starting the traversal from Value 1, the resulting SetVector is:
393   /// { 2, 7, 8, 5 }
394   SetVector<Value> findValueInReverseUseDefChain(
395       Value value, llvm::function_ref<bool(Value)> condition) const;
396 
397   /// Find the Values of the last preceding write of a given Value.
398   ///
399   /// Note: Unknown ops are handled conservatively and assumed to be writes.
400   /// Furthermore, BlockArguments are also assumed to be writes. There is no
401   /// analysis across block boundaries.
402   ///
403   /// Note: When reaching an end of the reverse SSA use-def chain, that value
404   /// is returned regardless of whether it is a memory write or not.
405   SetVector<Value> findLastPrecedingWrite(Value value) const;
406 
407   /// Return `true` if the given OpResult has been decided to bufferize inplace.
408   virtual bool isInPlace(OpOperand &opOperand) const;
409 
410   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
411   virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const;
412 
413   /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
414   virtual bool areAliasingBufferizedValues(Value v1, Value v2) const;
415 
416   /// Return `true` if the given tensor has undefined contents.
417   virtual bool hasUndefinedContents(OpOperand *opOperand) const;
418 
419   /// Return true if the given tensor (or an aliasing tensor) is yielded from
420   /// the containing block. Also include all aliasing tensors in the same block.
421   ///
422   /// Note: In the absence of an analysis, an implementation may return true for
423   /// any given tensor.
424   virtual bool isTensorYielded(Value tensor) const;
425 
426   /// Return `true` if the given dialect state exists.
hasDialectState(StringRef name)427   bool hasDialectState(StringRef name) const {
428     auto it = dialectState.find(name);
429     return it != dialectState.end();
430   }
431 
432   /// Return dialect-specific bufferization state.
433   template <typename StateT>
getDialectState(StringRef name)434   Optional<const StateT *> getDialectState(StringRef name) const {
435     auto it = dialectState.find(name);
436     if (it == dialectState.end())
437       return None;
438     return static_cast<const StateT *>(it->getSecond().get());
439   }
440 
441   /// Return dialect-specific analysis state or create one if none exists.
442   template <typename StateT>
getOrCreateDialectState(StringRef name)443   StateT &getOrCreateDialectState(StringRef name) {
444     // Create state if it does not exist yet.
445     if (!hasDialectState(name))
446       dialectState[name] = std::make_unique<StateT>();
447     return static_cast<StateT &>(*dialectState[name]);
448   }
449 
insertDialectState(StringRef name,std::unique_ptr<DialectAnalysisState> state)450   void insertDialectState(StringRef name,
451                           std::unique_ptr<DialectAnalysisState> state) {
452     assert(!dialectState.count(name) && "dialect state already initialized");
453     dialectState[name] = std::move(state);
454   }
455 
456   /// Return a reference to the BufferizationOptions.
getOptions()457   const BufferizationOptions &getOptions() const { return options; }
458 
459   explicit AnalysisState(const BufferizationOptions &options);
460 
461   // AnalysisState should be passed as a reference.
462   AnalysisState(const AnalysisState &) = delete;
463 
464   virtual ~AnalysisState() = default;
465 
466 private:
467   /// Dialect-specific analysis state.
468   DenseMap<StringRef, std::unique_ptr<DialectAnalysisState>> dialectState;
469 
470   /// A reference to current bufferization options.
471   const BufferizationOptions &options;
472 };
473 
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated. `copy` defaults to `true`.
///
/// NOTE(review): `escape` is not described here; presumably it states whether
/// the allocation may escape its block (cf. `allocationDoesNotEscape`) --
/// confirm against the definition.
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
                             bool escape, const BufferizationOptions &options,
                             bool copy = true);
481 
/// Return `true` if the allocation of the given op is guaranteed to not escape
/// the containing block.
///
/// Declared here only; how the guarantee is determined is part of the
/// out-of-line definition.
bool allocationDoesNotEscape(OpResult opResult);
485 
/// Look up the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
///
/// The result is wrapped in `FailureOr`; the failure conditions are defined by
/// the out-of-line implementation.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
                           const BufferizationOptions &options);
491 
/// Return the buffer type for a given Value (tensor) after bufferization.
///
/// Note: Op implementations should preferably call `getBuffer()->getType()`.
/// This function should only be used if `getBuffer` cannot be used.
FailureOr<BaseMemRefType> getBufferType(Value value,
                                        const BufferizationOptions &options);
498 
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
///
/// NOTE(review): `values` presumably has to supply one replacement per result
/// of `op` -- confirm against the definition.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
                                   ValueRange values);
503 
504 /// Replace an op with a new op. The new op must have the same number of
505 /// results as the replaced op. The new op may not return any tensor values.
506 template <typename OpTy, typename... Args>
replaceOpWithNewBufferizedOp(RewriterBase & rewriter,Operation * op,Args &&...args)507 OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
508                                   Args &&...args) {
509   auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
510   replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
511   return newOp;
512 }
513 
/// Return `true` if the buffer of given OpResult should be deallocated. This
/// function should be called during `BufferizableOpInterface::bufferize`
/// implementations that allocate a new buffer for the given OpResult.
///
/// NOTE(review): which fields of `options` influence the answer is decided by
/// the out-of-line definition (presumably `createDeallocs`) -- confirm.
bool shouldDeallocateOpResult(OpResult opResult,
                              const BufferizationOptions &options);
519 
/// Return a MemRefType to which the type of the given value can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
///
/// Unless a layout map was specified, `options.unknownTypeConverterFn`
/// determines what kind of layout map will be used. For best composability
/// (without copies), the fully dynamic layout map is used by default.
///
/// `layout` defaults to an empty (unspecified) layout attribute and
/// `memorySpace` defaults to 0.
///
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
                             MemRefLayoutAttrInterface layout = {},
                             unsigned memorySpace = 0);
535 
/// Return a MemRef type with fully dynamic layout. If the given tensor type
/// is unranked, return an unranked MemRef type. `memorySpace` defaults to 0.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace = 0);
540 
/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
/// `memorySpace` defaults to 0.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace = 0);
545 
546 } // namespace bufferization
547 } // namespace mlir
548 
549 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
550 
551 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
552