1 //===- VectorUtils.cpp - MLIR Utilities for VectorOps   ------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utility methods for working with the Vector dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
14 
15 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Support/MathExtras.h"
28 #include <numeric>
29 
30 #include "llvm/ADT/DenseSet.h"
31 #include "llvm/ADT/SetVector.h"
32 
33 using namespace mlir;
34 
35 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
36 /// the type of `source`.
37 Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
38                                       int64_t dim) {
39   if (source.getType().isa<UnrankedMemRefType, MemRefType>())
40     return b.createOrFold<memref::DimOp>(loc, source, dim);
41   if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
42     return b.createOrFold<tensor::DimOp>(loc, source, dim);
43   llvm_unreachable("Expected MemRefType or TensorType");
44 }
45 
46 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
47                                        CombiningKind kind, Value v1, Value v2) {
48   Type t1 = getElementTypeOrSelf(v1.getType());
49   Type t2 = getElementTypeOrSelf(v2.getType());
50   switch (kind) {
51   case CombiningKind::ADD:
52     if (t1.isIntOrIndex() && t2.isIntOrIndex())
53       return b.createOrFold<arith::AddIOp>(loc, v1, v2);
54     else if (t1.isa<FloatType>() && t2.isa<FloatType>())
55       return b.createOrFold<arith::AddFOp>(loc, v1, v2);
56     llvm_unreachable("invalid value types for ADD reduction");
57   case CombiningKind::AND:
58     assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
59     return b.createOrFold<arith::AndIOp>(loc, v1, v2);
60   case CombiningKind::MAXF:
61     assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
62            "expected float values");
63     return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
64   case CombiningKind::MINF:
65     assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
66            "expected float values");
67     return b.createOrFold<arith::MinFOp>(loc, v1, v2);
68   case CombiningKind::MAXSI:
69     assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
70     return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
71   case CombiningKind::MINSI:
72     assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
73     return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
74   case CombiningKind::MAXUI:
75     assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
76     return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
77   case CombiningKind::MINUI:
78     assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
79     return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
80   case CombiningKind::MUL:
81     if (t1.isIntOrIndex() && t2.isIntOrIndex())
82       return b.createOrFold<arith::MulIOp>(loc, v1, v2);
83     else if (t1.isa<FloatType>() && t2.isa<FloatType>())
84       return b.createOrFold<arith::MulFOp>(loc, v1, v2);
85     llvm_unreachable("invalid value types for MUL reduction");
86   case CombiningKind::OR:
87     assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
88     return b.createOrFold<arith::OrIOp>(loc, v1, v2);
89   case CombiningKind::XOR:
90     assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
91     return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
92   };
93   llvm_unreachable("unknown CombiningKind");
94 }
95 
96 /// Return the number of elements of basis, `0` if empty.
97 int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
98   if (basis.empty())
99     return 0;
100   return std::accumulate(basis.begin(), basis.end(), 1,
101                          std::multiplies<int64_t>());
102 }
103 
104 SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
105                                              ArrayRef<int64_t> sizes) {
106   int64_t rank = shape.size();
107   // Compute the count for each dimension.
108   SmallVector<int64_t, 4> sliceDimCounts(rank);
109   for (int64_t r = 0; r < rank; ++r)
110     sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
111   // Use that to compute the slice stride for each dimension.
112   SmallVector<int64_t, 4> sliceStrides(rank);
113   sliceStrides[rank - 1] = 1;
114   for (int64_t r = rank - 2; r >= 0; --r)
115     sliceStrides[r] = sliceStrides[r + 1] * sliceDimCounts[r + 1];
116   return sliceStrides;
117 }
118 
119 SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
120     ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
121   SmallVector<int64_t, 4> result;
122   for (auto it : llvm::zip(vectorOffsets, sizes))
123     result.push_back(std::get<0>(it) * std::get<1>(it));
124   return result;
125 }
126 
127 Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
128                                                    ArrayRef<int64_t> subShape) {
129   if (superShape.size() < subShape.size()) {
130     return Optional<SmallVector<int64_t, 4>>();
131   }
132 
133   // Starting from the end, compute the integer divisors.
134   std::vector<int64_t> result;
135   result.reserve(superShape.size());
136   int64_t superSize = 0, subSize = 0;
137   for (auto it :
138        llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) {
139     std::tie(superSize, subSize) = it;
140     assert(superSize > 0 && "superSize must be > 0");
141     assert(subSize > 0 && "subSize must be > 0");
142 
143     // If integral division does not occur, return and let the caller decide.
144     if (superSize % subSize != 0)
145       return None;
146     result.push_back(superSize / subSize);
147   }
148 
149   // At this point we computed the ratio (in reverse) for the common
150   // size. Fill with the remaining entries from the super-vector shape (still in
151   // reverse).
152   int commonSize = subShape.size();
153   std::copy(superShape.rbegin() + commonSize, superShape.rend(),
154             std::back_inserter(result));
155 
156   assert(result.size() == superShape.size() &&
157          "super to sub shape ratio is not of the same size as the super rank");
158 
159   // Reverse again to get it back in the proper order and return.
160   return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
161 }
162 
163 Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
164                                                    VectorType subVectorType) {
165   assert(superVectorType.getElementType() == subVectorType.getElementType() &&
166          "vector types must be of the same elemental type");
167   return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
168 }
169 
170 /// Constructs a permutation map from memref indices to vector dimension.
171 ///
172 /// The implementation uses the knowledge of the mapping of enclosing loop to
173 /// vector dimension. `enclosingLoopToVectorDim` carries this information as a
174 /// map with:
175 ///   - keys representing "vectorized enclosing loops";
176 ///   - values representing the corresponding vector dimension.
177 /// The algorithm traverses "vectorized enclosing loops" and extracts the
178 /// at-most-one MemRef index that is invariant along said loop. This index is
179 /// guaranteed to be at most one by construction: otherwise the MemRef is not
180 /// vectorizable.
181 /// If this invariant index is found, it is added to the permutation_map at the
182 /// proper vector dimension.
183 /// If no index is found to be invariant, 0 is added to the permutation_map and
184 /// corresponds to a vector broadcast along that dimension.
185 ///
186 /// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
187 /// signalling that no permutation map can be constructed given
188 /// `enclosingLoopToVectorDim`.
189 ///
190 /// Examples can be found in the documentation of `makePermutationMap`, in the
191 /// header file.
static AffineMap makePermutationMap(
    ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
  // No vectorized enclosing loop: signal failure with an empty map.
  if (enclosingLoopToVectorDim.empty())
    return AffineMap();
  MLIRContext *context =
      enclosingLoopToVectorDim.begin()->getFirst()->getContext();
  // Start with all vector dimensions mapping to the constant 0 (broadcast);
  // entries are overwritten below when a varying index is found.
  SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
                                  getAffineConstantExpr(0, context));

  for (auto kvp : enclosingLoopToVectorDim) {
    assert(kvp.second < perm.size());
    // Indices invariant w.r.t. this loop's induction variable do not
    // contribute to the vector dimension associated with the loop.
    auto invariants = getInvariantAccesses(
        cast<AffineForOp>(kvp.first).getInductionVar(), indices);
    unsigned numIndices = indices.size();
    unsigned countInvariantIndices = 0;
    for (unsigned dim = 0; dim < numIndices; ++dim) {
      if (!invariants.count(indices[dim])) {
        // This index varies along the loop: map the loop's vector dimension
        // to this memref dimension. By the vectorization prerequisite there
        // can be at most one such index per loop.
        assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
               "permutationMap already has an entry along dim");
        perm[kvp.second] = getAffineDimExpr(dim, context);
      } else {
        ++countInvariantIndices;
      }
    }
    assert((countInvariantIndices == numIndices ||
            countInvariantIndices == numIndices - 1) &&
           "Vectorization prerequisite violated: at most 1 index may be "
           "invariant wrt a vectorized loop");
  }
  return AffineMap::get(indices.size(), 0, perm, context);
}
224 
225 /// Implementation detail that walks up the parents and records the ones with
226 /// the specified type.
227 /// TODO: could also be implemented as a collect parents followed by a
228 /// filter and made available outside this file.
229 template <typename T>
230 static SetVector<Operation *> getParentsOfType(Block *block) {
231   SetVector<Operation *> res;
232   auto *current = block->getParentOp();
233   while (current) {
234     if (auto typedParent = dyn_cast<T>(current)) {
235       assert(res.count(current) == 0 && "Already inserted");
236       res.insert(current);
237     }
238     current = current->getParentOp();
239   }
240   return res;
241 }
242 
/// Returns the enclosing AffineForOp, from closest to farthest.
/// (getParentsOfType walks outward and SetVector preserves insertion order,
/// so the innermost loop comes first.)
static SetVector<Operation *> getEnclosingforOps(Block *block) {
  return getParentsOfType<AffineForOp>(block);
}
247 
248 AffineMap mlir::makePermutationMap(
249     Block *insertPoint, ArrayRef<Value> indices,
250     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
251   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
252   auto enclosingLoops = getEnclosingforOps(insertPoint);
253   for (auto *forInst : enclosingLoops) {
254     auto it = loopToVectorDim.find(forInst);
255     if (it != loopToVectorDim.end()) {
256       enclosingLoopToVectorDim.insert(*it);
257     }
258   }
259   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
260 }
261 
/// Overload taking an operation as the insertion point: the permutation map
/// is built relative to the block containing `op`.
AffineMap mlir::makePermutationMap(
    Operation *op, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
}
267 
268 bool matcher::operatesOnSuperVectorsOf(Operation &op,
269                                        VectorType subVectorType) {
270   // First, extract the vector type and distinguish between:
271   //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
272   //      vector.transfer_write); and
273   //   b. ops that *may* lower a super-vector (all other ops).
274   // The ops that *may* lower a super-vector only do so if the super-vector to
275   // sub-vector ratio exists. The ops that *must* lower a super-vector are
276   // explicitly checked for this property.
277   /// TODO: there should be a single function for all ops to do this so we
278   /// do not have to special case. Maybe a trait, or just a method, unclear atm.
279   bool mustDivide = false;
280   (void)mustDivide;
281   VectorType superVectorType;
282   if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
283     superVectorType = transfer.getVectorType();
284     mustDivide = true;
285   } else if (op.getNumResults() == 0) {
286     if (!isa<func::ReturnOp>(op)) {
287       op.emitError("NYI: assuming only return operations can have 0 "
288                    " results at this point");
289     }
290     return false;
291   } else if (op.getNumResults() == 1) {
292     if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) {
293       superVectorType = v;
294     } else {
295       // Not a vector type.
296       return false;
297     }
298   } else {
299     // Not a vector.transfer and has more than 1 result, fail hard for now to
300     // wake us up when something changes.
301     op.emitError("NYI: operation has more than 1 result");
302     return false;
303   }
304 
305   // Get the ratio.
306   auto ratio = shapeRatio(superVectorType, subVectorType);
307 
308   // Sanity check.
309   assert((ratio.hasValue() || !mustDivide) &&
310          "vector.transfer operation in which super-vector size is not an"
311          " integer multiple of sub-vector size");
312 
313   // This catches cases that are not strictly necessary to have multiplicity but
314   // still aren't divisible by the sub-vector shape.
315   // This could be useful information if we wanted to reshape at the level of
316   // the vector type (but we would have to look at the compute and distinguish
317   // between parallel, reduction and possibly other cases.
318   return ratio.hasValue();
319 }
320