1 //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H
11 
12 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
13 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
14 
15 namespace mlir {
16 class MLIRContext;
17 class VectorTransferOpInterface;
18 class RewritePatternSet;
19 class RewriterBase;
20 
21 namespace scf {
22 class IfOp;
23 } // namespace scf
24 
25 namespace vector {
26 
27 //===----------------------------------------------------------------------===//
28 // Standalone transformations and helpers.
29 //===----------------------------------------------------------------------===//
30 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
31 /// masking) fastpath and a slowpath.
32 /// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
33 /// newly created conditional upon function return.
34 /// To accommodate the fact that the original vector.transfer indexing may be
35 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
36 /// scf.if op returns a view and values of type index.
37 /// At this time, only the vector.transfer_read case is implemented.
38 ///
39 /// Example (a 2-D vector.transfer_read):
40 /// ```
41 ///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
42 /// ```
43 /// is transformed into:
44 /// ```
45 ///    %1:3 = scf.if (%inBounds) {
46 ///      // fastpath, direct cast
47 ///      memref.cast %A: memref<A...> to compatibleMemRefType
48 ///      scf.yield %view : compatibleMemRefType, index, index
49 ///    } else {
50 ///      // slowpath, not in-bounds vector.transfer or linalg.copy.
51 ///      memref.cast %alloc: memref<B...> to compatibleMemRefType
52 ///      scf.yield %4 : compatibleMemRefType, index, index
53 ///    }
54 ///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
55 /// ```
56 /// where `alloc` is a top-of-function alloca'ed buffer of one vector.
57 ///
58 /// Preconditions:
59 ///  1. `xferOp.permutation_map()` must be a minor identity map
60 ///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
61 ///     must be equal. This will be relaxed in the future but requires
62 ///     rank-reducing subviews.
63 LogicalResult splitFullAndPartialTransfer(
64     RewriterBase &b, VectorTransferOpInterface xferOp,
65     VectorTransformsOptions options = VectorTransformsOptions(),
66     scf::IfOp *ifOp = nullptr);
67 
68 struct DistributeOps { // Pair of ops returned by distributPointwiseVectorOp.
69   ExtractMapOp extract; // Presumably the inserted vector.extract_map op.
70   InsertMapOp insert;   // Presumably the inserted vector.insert_map op.
71 };
72 
73 /// Distribute an N-D vector pointwise operation over a range of given ids
74 /// taking *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable
75 /// or SPMD id). This transformation only inserts vector.extract_map and
76 /// vector.insert_map. It is meant to be used with canonicalization patterns
77 /// to propagate and fold the vector insert_map/extract_map operations.
78 /// The result is optional; presumably empty when `op` cannot be distributed.
79 /// Transforms:
80 /// %v = arith.addf %a, %b : vector<32xf32>
81 /// to:
82 /// %v = arith.addf %a, %b : vector<32xf32>
83 /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
84 /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32>
85 Optional<DistributeOps>
86 distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
87                            ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
88                            const AffineMap &map);
89 
90 /// Implements transfer-op write-to-read forwarding and dead-transfer-write
91 /// optimizations. Presumably applies to the ops nested under `rootOp`.
92 void transferOpflowOpt(Operation *rootOp);
93 
94 } // namespace vector
95 } // namespace mlir
96 
97 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H
98