1 //===- VectorDistribution.h - Vector distribution patterns --*- C++------*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_ 10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_ 11 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 14 namespace mlir { 15 class RewritePatternSet; 16 namespace vector { 17 18 struct WarpExecuteOnLane0LoweringOptions { 19 /// Lamdba function to let users allocate memory needed for the lowering of 20 /// WarpExecuteOnLane0Op. 21 /// The function needs to return an allocation that the lowering can use as 22 /// temporary memory. The allocation needs to match the shape of the type (the 23 /// type may be VectorType or a scalar) and be availble for the current warp. 24 /// If there are several warps running in parallel the allocation needs to be 25 /// split so that each warp has its own allocation. 26 using WarpAllocationFn = 27 std::function<Value(Location, OpBuilder &, WarpExecuteOnLane0Op, Type)>; 28 WarpAllocationFn warpAllocationFn = nullptr; 29 30 /// Lamdba function to let user emit operation to syncronize all the thread 31 /// within a warp. After this operation all the threads can see any memory 32 /// written before the operation. 33 using WarpSyncronizationFn = 34 std::function<void(Location, OpBuilder &, WarpExecuteOnLane0Op)>; 35 WarpSyncronizationFn warpSyncronizationFn = nullptr; 36 }; 37 38 void populateWarpExecuteOnLane0OpToScfForPattern( 39 RewritePatternSet &patterns, 40 const WarpExecuteOnLane0LoweringOptions &options); 41 42 using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>; 43 44 /// Distribute transfer_write ops based on the affine map returned by 45 /// `distributionMapFn`. 46 /// Example: 47 /// ``` 48 /// %0 = vector.warp_execute_on_lane_0(%id){ 49 /// ... 50 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32> 51 /// vector.yield 52 /// } 53 /// ``` 54 /// To 55 /// ``` 56 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) { 57 /// ... 58 /// vector.yield %v : vector<32xf32> 59 /// } 60 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32> 61 void populateDistributeTransferWriteOpPatterns( 62 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn); 63 64 /// Move scalar operations with no dependency on the warp op outside of the 65 /// region. 66 void moveScalarUniformCode(WarpExecuteOnLane0Op op); 67 68 /// Collect patterns to propagate warp distribution. 69 void populatePropagateWarpVectorDistributionPatterns( 70 RewritePatternSet &pattern); 71 72 /// Lambda signature to compute a reduction of a distributed value for the given 73 /// reduction kind and size. 74 using DistributedReductionFn = 75 std::function<Value(Location, OpBuilder &, Value, CombiningKind, uint32_t)>; 76 77 /// Collect patterns to distribute vector reduction ops using given lamdba to 78 /// distribute reduction op. 79 void populateDistributeReduction(RewritePatternSet &pattern, 80 DistributedReductionFn distributedReductionFn); 81 82 } // namespace vector 83 } // namespace mlir 84 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_ 85