1 //===- VectorDistribution.h - Vector distribution patterns --*- C++------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
11 
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 
14 namespace mlir {
15 class RewritePatternSet;
16 namespace vector {
17 
18 struct WarpExecuteOnLane0LoweringOptions {
19   /// Lamdba function to let users allocate memory needed for the lowering of
20   /// WarpExecuteOnLane0Op.
21   /// The function needs to return an allocation that the lowering can use as
22   /// temporary memory. The allocation needs to match the shape of the type (the
23   /// type may be VectorType or a scalar) and be availble for the current warp.
24   /// If there are several warps running in parallel the allocation needs to be
25   /// split so that each warp has its own allocation.
26   using WarpAllocationFn =
27       std::function<Value(Location, OpBuilder &, WarpExecuteOnLane0Op, Type)>;
28   WarpAllocationFn warpAllocationFn = nullptr;
29 
30   /// Lamdba function to let user emit operation to syncronize all the thread
31   /// within a warp. After this operation all the threads can see any memory
32   /// written before the operation.
33   using WarpSyncronizationFn =
34       std::function<void(Location, OpBuilder &, WarpExecuteOnLane0Op)>;
35   WarpSyncronizationFn warpSyncronizationFn = nullptr;
36 };
37 
/// Add to `patterns` the pattern that lowers WarpExecuteOnLane0Op.
/// `options` provides the user callbacks the lowering relies on: allocation of
/// temporary memory (`warpAllocationFn`) and synchronization of the threads of
/// a warp (`warpSyncronizationFn`).
void populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options);
41 
42 using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
43 
44 /// Distribute transfer_write ops based on the affine map returned by
45 /// `distributionMapFn`.
46 /// Example:
47 /// ```
48 /// %0 = vector.warp_execute_on_lane_0(%id){
49 ///   ...
50 ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
51 ///   vector.yield
52 /// }
53 /// ```
54 /// To
55 /// ```
56 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
57 ///   ...
58 ///   vector.yield %v : vector<32xf32>
59 /// }
60 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
61 void populateDistributeTransferWriteOpPatterns(
62     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn);
63 
/// Move the scalar operations inside the region of `op` that have no
/// dependency on the warp op out of that region.
void moveScalarUniformCode(WarpExecuteOnLane0Op op);
67 
68 /// Collect patterns to propagate warp distribution.
69 void populatePropagateWarpVectorDistributionPatterns(
70     RewritePatternSet &pattern);
71 
72 /// Lambda signature to compute a reduction of a distributed value for the given
73 /// reduction kind and size.
74 using DistributedReductionFn =
75     std::function<Value(Location, OpBuilder &, Value, CombiningKind, uint32_t)>;
76 
77 /// Collect patterns to distribute vector reduction ops using given lamdba to
78 /// distribute reduction op.
79 void populateDistributeReduction(RewritePatternSet &pattern,
80                                  DistributedReductionFn distributedReductionFn);
81 
82 } // namespace vector
83 } // namespace mlir
84 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
85