1 //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H 10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H 11 12 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 13 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 14 15 namespace mlir { 16 class MLIRContext; 17 class VectorTransferOpInterface; 18 class RewritePatternSet; 19 class RewriterBase; 20 21 namespace scf { 22 class IfOp; 23 } // namespace scf 24 25 namespace vector { 26 27 //===----------------------------------------------------------------------===// 28 // Standalone transformations and helpers. 29 //===----------------------------------------------------------------------===// 30 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds 31 /// masking) fastpath and a slowpath. 32 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the 33 /// newly created conditional upon function return. 34 /// To accomodate for the fact that the original vector.transfer indexing may be 35 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the 36 /// scf.if op returns a view and values of type index. 37 /// At this time, only vector.transfer_read case is implemented. 38 /// 39 /// Example (a 2-D vector.transfer_read): 40 /// ``` 41 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...> 42 /// ``` 43 /// is transformed into: 44 /// ``` 45 /// %1:3 = scf.if (%inBounds) { 46 /// // fastpath, direct cast 47 /// memref.cast %A: memref<A...> to compatibleMemRefType 48 /// scf.yield %view : compatibleMemRefType, index, index 49 /// } else { 50 /// // slowpath, not in-bounds vector.transfer or linalg.copy. 51 /// memref.cast %alloc: memref<B...> to compatibleMemRefType 52 /// scf.yield %4 : compatibleMemRefType, index, index 53 // } 54 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]} 55 /// ``` 56 /// where `alloc` is a top of the function alloca'ed buffer of one vector. 57 /// 58 /// Preconditions: 59 /// 1. `xferOp.permutation_map()` must be a minor identity map 60 /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` 61 /// must be equal. This will be relaxed in the future but requires 62 /// rank-reducing subviews. 63 LogicalResult splitFullAndPartialTransfer( 64 RewriterBase &b, VectorTransferOpInterface xferOp, 65 VectorTransformsOptions options = VectorTransformsOptions(), 66 scf::IfOp *ifOp = nullptr); 67 68 struct DistributeOps { 69 ExtractMapOp extract; 70 InsertMapOp insert; 71 }; 72 73 /// Distribute a N-D vector pointwise operation over a range of given ids taking 74 /// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or 75 /// SPMD id). This transformation only inserts 76 /// vector.extract_map/vector.insert_map. It is meant to be used with 77 /// canonicalizations pattern to propagate and fold the vector 78 /// insert_map/extract_map operations. 79 /// Transforms: 80 // %v = arith.addf %a, %b : vector<32xf32> 81 /// to: 82 /// %v = arith.addf %a, %b : vector<32xf32> 83 /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32> 84 /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32> 85 Optional<DistributeOps> 86 distributPointwiseVectorOp(OpBuilder &builder, Operation *op, 87 ArrayRef<Value> id, ArrayRef<int64_t> multiplicity, 88 const AffineMap &map); 89 90 /// Implements transfer op write to read forwarding and dead transfer write 91 /// optimizations. 92 void transferOpflowOpt(Operation *rootOp); 93 94 } // namespace vector 95 } // namespace mlir 96 97 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H 98