1 //===- VectorToSCF.h - Convert vector to SCF dialect ------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ 10 #define MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ 11 12 #include "mlir/IR/PatternMatch.h" 13 14 namespace mlir { 15 class MLIRContext; 16 class Pass; 17 class RewritePatternSet; 18 19 /// When lowering an N-d vector transfer op to an (N-1)-d vector transfer op, 20 /// a temporary buffer is created through which individual (N-1)-d vector are 21 /// staged. This pattern can be applied multiple time, until the transfer op 22 /// is 1-d. 23 /// This is consistent with the lack of an LLVM instruction to dynamically 24 /// index into an aggregate (see the Vector dialect lowering to LLVM deep dive). 25 /// 26 /// An instruction such as: 27 /// ``` 28 /// vector.transfer_write %vec, %A[%a, %b, %c] : 29 /// vector<9x17x15xf32>, memref<?x?x?xf32> 30 /// ``` 31 /// Lowers to pseudo-IR resembling (unpacking one dimension): 32 /// ``` 33 /// %0 = alloca() : memref<vector<9x17x15xf32>> 34 /// store %vec, %0[] : memref<vector<9x17x15xf32>> 35 /// %1 = vector.type_cast %0 : 36 /// memref<vector<9x17x15xf32>> to memref<9xvector<17x15xf32>> 37 /// affine.for %I = 0 to 9 { 38 /// %dim = dim %A, 0 : memref<?x?x?xf32> 39 /// %add = affine.apply %I + %a 40 /// %cmp = arith.cmpi "slt", %add, %dim : index 41 /// scf.if %cmp { 42 /// %vec_2d = load %1[%I] : memref<9xvector<17x15xf32>> 43 /// vector.transfer_write %vec_2d, %A[%add, %b, %c] : 44 /// vector<17x15xf32>, memref<?x?x?xf32> 45 /// ``` 46 /// 47 /// When applying the pattern a second time, the existing alloca() operation 48 /// is reused and only a second vector.type_cast is added. 49 struct VectorTransferToSCFOptions { 50 /// Minimal rank to which vector transfer are lowered. 51 unsigned targetRank = 1; setTargetRankVectorTransferToSCFOptions52 VectorTransferToSCFOptions &setTargetRank(unsigned r) { 53 targetRank = r; 54 return *this; 55 } 56 /// 57 bool lowerPermutationMaps = false; 58 VectorTransferToSCFOptions &enableLowerPermutationMaps(bool l = true) { 59 lowerPermutationMaps = l; 60 return *this; 61 } 62 /// Allows vector transfers that operated on tensors to be lowered (this is an 63 /// uncommon alternative). 64 bool lowerTensors = false; 65 VectorTransferToSCFOptions &enableLowerTensors(bool l = true) { 66 lowerTensors = l; 67 return *this; 68 } 69 /// Triggers full unrolling (vs iterating with a loop) during transfer to scf. 70 bool unroll = false; 71 VectorTransferToSCFOptions &enableFullUnroll(bool u = true) { 72 unroll = u; 73 return *this; 74 } 75 }; 76 77 /// Collect a set of patterns to convert from the Vector dialect to SCF + func. 78 void populateVectorToSCFConversionPatterns( 79 RewritePatternSet &patterns, 80 const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions()); 81 82 /// Create a pass to convert a subset of vector ops to SCF. 83 std::unique_ptr<Pass> createConvertVectorToSCFPass( 84 const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions()); 85 86 } // namespace mlir 87 88 #endif // MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ 89