199ef9eebSMatthias Springer //===- VectorUtils.cpp - MLIR Utilities for VectorOps   ------------------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements utility methods for working with the Vector dialect.
1099ef9eebSMatthias Springer //
1199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1299ef9eebSMatthias Springer 
1399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1499ef9eebSMatthias Springer 
1599ef9eebSMatthias Springer #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1999ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
2199ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2299ef9eebSMatthias Springer #include "mlir/IR/Builders.h"
2399ef9eebSMatthias Springer #include "mlir/IR/IntegerSet.h"
2499ef9eebSMatthias Springer #include "mlir/IR/Operation.h"
259b5a3d14SMatthias Springer #include "mlir/IR/TypeUtilities.h"
2699ef9eebSMatthias Springer #include "mlir/Support/LLVM.h"
2799ef9eebSMatthias Springer #include "mlir/Support/MathExtras.h"
2899ef9eebSMatthias Springer #include <numeric>
2999ef9eebSMatthias Springer 
3099ef9eebSMatthias Springer #include "llvm/ADT/DenseSet.h"
3199ef9eebSMatthias Springer #include "llvm/ADT/SetVector.h"
3299ef9eebSMatthias Springer 
3399ef9eebSMatthias Springer using namespace mlir;
3499ef9eebSMatthias Springer 
3599ef9eebSMatthias Springer /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
3699ef9eebSMatthias Springer /// the type of `source`.
createOrFoldDimOp(OpBuilder & b,Location loc,Value source,int64_t dim)3799ef9eebSMatthias Springer Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
3899ef9eebSMatthias Springer                                       int64_t dim) {
3999ef9eebSMatthias Springer   if (source.getType().isa<UnrankedMemRefType, MemRefType>())
4099ef9eebSMatthias Springer     return b.createOrFold<memref::DimOp>(loc, source, dim);
4199ef9eebSMatthias Springer   if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
4299ef9eebSMatthias Springer     return b.createOrFold<tensor::DimOp>(loc, source, dim);
4399ef9eebSMatthias Springer   llvm_unreachable("Expected MemRefType or TensorType");
4499ef9eebSMatthias Springer }
4599ef9eebSMatthias Springer 
4699ef9eebSMatthias Springer /// Return the number of elements of basis, `0` if empty.
computeMaxLinearIndex(ArrayRef<int64_t> basis)4799ef9eebSMatthias Springer int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
4899ef9eebSMatthias Springer   if (basis.empty())
4999ef9eebSMatthias Springer     return 0;
5099ef9eebSMatthias Springer   return std::accumulate(basis.begin(), basis.end(), 1,
5199ef9eebSMatthias Springer                          std::multiplies<int64_t>());
5299ef9eebSMatthias Springer }
5399ef9eebSMatthias Springer 
computeStrides(ArrayRef<int64_t> shape,ArrayRef<int64_t> sizes)5499ef9eebSMatthias Springer SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
5599ef9eebSMatthias Springer                                              ArrayRef<int64_t> sizes) {
5699ef9eebSMatthias Springer   int64_t rank = shape.size();
5799ef9eebSMatthias Springer   // Compute the count for each dimension.
5899ef9eebSMatthias Springer   SmallVector<int64_t, 4> sliceDimCounts(rank);
5999ef9eebSMatthias Springer   for (int64_t r = 0; r < rank; ++r)
6099ef9eebSMatthias Springer     sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
6199ef9eebSMatthias Springer   // Use that to compute the slice stride for each dimension.
6299ef9eebSMatthias Springer   SmallVector<int64_t, 4> sliceStrides(rank);
6399ef9eebSMatthias Springer   sliceStrides[rank - 1] = 1;
6499ef9eebSMatthias Springer   for (int64_t r = rank - 2; r >= 0; --r)
6599ef9eebSMatthias Springer     sliceStrides[r] = sliceStrides[r + 1] * sliceDimCounts[r + 1];
6699ef9eebSMatthias Springer   return sliceStrides;
6799ef9eebSMatthias Springer }
6899ef9eebSMatthias Springer 
computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,ArrayRef<int64_t> vectorOffsets)6999ef9eebSMatthias Springer SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
7099ef9eebSMatthias Springer     ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
7199ef9eebSMatthias Springer   SmallVector<int64_t, 4> result;
7299ef9eebSMatthias Springer   for (auto it : llvm::zip(vectorOffsets, sizes))
7399ef9eebSMatthias Springer     result.push_back(std::get<0>(it) * std::get<1>(it));
7499ef9eebSMatthias Springer   return result;
7599ef9eebSMatthias Springer }
7699ef9eebSMatthias Springer 
shapeRatio(ArrayRef<int64_t> superShape,ArrayRef<int64_t> subShape)7799ef9eebSMatthias Springer Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
7899ef9eebSMatthias Springer                                                    ArrayRef<int64_t> subShape) {
7999ef9eebSMatthias Springer   if (superShape.size() < subShape.size()) {
8099ef9eebSMatthias Springer     return Optional<SmallVector<int64_t, 4>>();
8199ef9eebSMatthias Springer   }
8299ef9eebSMatthias Springer 
8399ef9eebSMatthias Springer   // Starting from the end, compute the integer divisors.
8499ef9eebSMatthias Springer   std::vector<int64_t> result;
8599ef9eebSMatthias Springer   result.reserve(superShape.size());
8699ef9eebSMatthias Springer   int64_t superSize = 0, subSize = 0;
8799ef9eebSMatthias Springer   for (auto it :
8899ef9eebSMatthias Springer        llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) {
8999ef9eebSMatthias Springer     std::tie(superSize, subSize) = it;
9099ef9eebSMatthias Springer     assert(superSize > 0 && "superSize must be > 0");
9199ef9eebSMatthias Springer     assert(subSize > 0 && "subSize must be > 0");
9299ef9eebSMatthias Springer 
9399ef9eebSMatthias Springer     // If integral division does not occur, return and let the caller decide.
9499ef9eebSMatthias Springer     if (superSize % subSize != 0)
9599ef9eebSMatthias Springer       return None;
9699ef9eebSMatthias Springer     result.push_back(superSize / subSize);
9799ef9eebSMatthias Springer   }
9899ef9eebSMatthias Springer 
9999ef9eebSMatthias Springer   // At this point we computed the ratio (in reverse) for the common
10099ef9eebSMatthias Springer   // size. Fill with the remaining entries from the super-vector shape (still in
10199ef9eebSMatthias Springer   // reverse).
10299ef9eebSMatthias Springer   int commonSize = subShape.size();
10399ef9eebSMatthias Springer   std::copy(superShape.rbegin() + commonSize, superShape.rend(),
10499ef9eebSMatthias Springer             std::back_inserter(result));
10599ef9eebSMatthias Springer 
10699ef9eebSMatthias Springer   assert(result.size() == superShape.size() &&
10799ef9eebSMatthias Springer          "super to sub shape ratio is not of the same size as the super rank");
10899ef9eebSMatthias Springer 
10999ef9eebSMatthias Springer   // Reverse again to get it back in the proper order and return.
11099ef9eebSMatthias Springer   return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
11199ef9eebSMatthias Springer }
11299ef9eebSMatthias Springer 
shapeRatio(VectorType superVectorType,VectorType subVectorType)11399ef9eebSMatthias Springer Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
11499ef9eebSMatthias Springer                                                    VectorType subVectorType) {
11599ef9eebSMatthias Springer   assert(superVectorType.getElementType() == subVectorType.getElementType() &&
11699ef9eebSMatthias Springer          "vector types must be of the same elemental type");
11799ef9eebSMatthias Springer   return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
11899ef9eebSMatthias Springer }
11999ef9eebSMatthias Springer 
12099ef9eebSMatthias Springer /// Constructs a permutation map from memref indices to vector dimension.
12199ef9eebSMatthias Springer ///
12299ef9eebSMatthias Springer /// The implementation uses the knowledge of the mapping of enclosing loop to
12399ef9eebSMatthias Springer /// vector dimension. `enclosingLoopToVectorDim` carries this information as a
12499ef9eebSMatthias Springer /// map with:
12599ef9eebSMatthias Springer ///   - keys representing "vectorized enclosing loops";
12699ef9eebSMatthias Springer ///   - values representing the corresponding vector dimension.
12799ef9eebSMatthias Springer /// The algorithm traverses "vectorized enclosing loops" and extracts the
12899ef9eebSMatthias Springer /// at-most-one MemRef index that is invariant along said loop. This index is
12999ef9eebSMatthias Springer /// guaranteed to be at most one by construction: otherwise the MemRef is not
13099ef9eebSMatthias Springer /// vectorizable.
13199ef9eebSMatthias Springer /// If this invariant index is found, it is added to the permutation_map at the
13299ef9eebSMatthias Springer /// proper vector dimension.
13399ef9eebSMatthias Springer /// If no index is found to be invariant, 0 is added to the permutation_map and
13499ef9eebSMatthias Springer /// corresponds to a vector broadcast along that dimension.
13599ef9eebSMatthias Springer ///
13699ef9eebSMatthias Springer /// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
13799ef9eebSMatthias Springer /// signalling that no permutation map can be constructed given
13899ef9eebSMatthias Springer /// `enclosingLoopToVectorDim`.
13999ef9eebSMatthias Springer ///
14099ef9eebSMatthias Springer /// Examples can be found in the documentation of `makePermutationMap`, in the
14199ef9eebSMatthias Springer /// header file.
makePermutationMap(ArrayRef<Value> indices,const DenseMap<Operation *,unsigned> & enclosingLoopToVectorDim)14299ef9eebSMatthias Springer static AffineMap makePermutationMap(
14399ef9eebSMatthias Springer     ArrayRef<Value> indices,
14499ef9eebSMatthias Springer     const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
14599ef9eebSMatthias Springer   if (enclosingLoopToVectorDim.empty())
14699ef9eebSMatthias Springer     return AffineMap();
14799ef9eebSMatthias Springer   MLIRContext *context =
14899ef9eebSMatthias Springer       enclosingLoopToVectorDim.begin()->getFirst()->getContext();
14999ef9eebSMatthias Springer   SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
15099ef9eebSMatthias Springer                                   getAffineConstantExpr(0, context));
15199ef9eebSMatthias Springer 
15299ef9eebSMatthias Springer   for (auto kvp : enclosingLoopToVectorDim) {
15399ef9eebSMatthias Springer     assert(kvp.second < perm.size());
15499ef9eebSMatthias Springer     auto invariants = getInvariantAccesses(
15599ef9eebSMatthias Springer         cast<AffineForOp>(kvp.first).getInductionVar(), indices);
15699ef9eebSMatthias Springer     unsigned numIndices = indices.size();
15799ef9eebSMatthias Springer     unsigned countInvariantIndices = 0;
15899ef9eebSMatthias Springer     for (unsigned dim = 0; dim < numIndices; ++dim) {
15999ef9eebSMatthias Springer       if (!invariants.count(indices[dim])) {
16099ef9eebSMatthias Springer         assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
16199ef9eebSMatthias Springer                "permutationMap already has an entry along dim");
16299ef9eebSMatthias Springer         perm[kvp.second] = getAffineDimExpr(dim, context);
16399ef9eebSMatthias Springer       } else {
16499ef9eebSMatthias Springer         ++countInvariantIndices;
16599ef9eebSMatthias Springer       }
16699ef9eebSMatthias Springer     }
16799ef9eebSMatthias Springer     assert((countInvariantIndices == numIndices ||
16899ef9eebSMatthias Springer             countInvariantIndices == numIndices - 1) &&
16999ef9eebSMatthias Springer            "Vectorization prerequisite violated: at most 1 index may be "
17099ef9eebSMatthias Springer            "invariant wrt a vectorized loop");
17199ef9eebSMatthias Springer   }
17299ef9eebSMatthias Springer   return AffineMap::get(indices.size(), 0, perm, context);
17399ef9eebSMatthias Springer }
17499ef9eebSMatthias Springer 
17599ef9eebSMatthias Springer /// Implementation detail that walks up the parents and records the ones with
17699ef9eebSMatthias Springer /// the specified type.
17799ef9eebSMatthias Springer /// TODO: could also be implemented as a collect parents followed by a
17899ef9eebSMatthias Springer /// filter and made available outside this file.
17999ef9eebSMatthias Springer template <typename T>
getParentsOfType(Block * block)18099ef9eebSMatthias Springer static SetVector<Operation *> getParentsOfType(Block *block) {
18199ef9eebSMatthias Springer   SetVector<Operation *> res;
18299ef9eebSMatthias Springer   auto *current = block->getParentOp();
18399ef9eebSMatthias Springer   while (current) {
18499ef9eebSMatthias Springer     if (auto typedParent = dyn_cast<T>(current)) {
18599ef9eebSMatthias Springer       assert(res.count(current) == 0 && "Already inserted");
18699ef9eebSMatthias Springer       res.insert(current);
18799ef9eebSMatthias Springer     }
18899ef9eebSMatthias Springer     current = current->getParentOp();
18999ef9eebSMatthias Springer   }
19099ef9eebSMatthias Springer   return res;
19199ef9eebSMatthias Springer }
19299ef9eebSMatthias Springer 
19399ef9eebSMatthias Springer /// Returns the enclosing AffineForOp, from closest to farthest.
getEnclosingforOps(Block * block)19499ef9eebSMatthias Springer static SetVector<Operation *> getEnclosingforOps(Block *block) {
19599ef9eebSMatthias Springer   return getParentsOfType<AffineForOp>(block);
19699ef9eebSMatthias Springer }
19799ef9eebSMatthias Springer 
makePermutationMap(Block * insertPoint,ArrayRef<Value> indices,const DenseMap<Operation *,unsigned> & loopToVectorDim)19899ef9eebSMatthias Springer AffineMap mlir::makePermutationMap(
19999ef9eebSMatthias Springer     Block *insertPoint, ArrayRef<Value> indices,
20099ef9eebSMatthias Springer     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
20199ef9eebSMatthias Springer   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
20299ef9eebSMatthias Springer   auto enclosingLoops = getEnclosingforOps(insertPoint);
20399ef9eebSMatthias Springer   for (auto *forInst : enclosingLoops) {
20499ef9eebSMatthias Springer     auto it = loopToVectorDim.find(forInst);
20599ef9eebSMatthias Springer     if (it != loopToVectorDim.end()) {
20699ef9eebSMatthias Springer       enclosingLoopToVectorDim.insert(*it);
20799ef9eebSMatthias Springer     }
20899ef9eebSMatthias Springer   }
20999ef9eebSMatthias Springer   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
21099ef9eebSMatthias Springer }
21199ef9eebSMatthias Springer 
makePermutationMap(Operation * op,ArrayRef<Value> indices,const DenseMap<Operation *,unsigned> & loopToVectorDim)21299ef9eebSMatthias Springer AffineMap mlir::makePermutationMap(
21399ef9eebSMatthias Springer     Operation *op, ArrayRef<Value> indices,
21499ef9eebSMatthias Springer     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
21599ef9eebSMatthias Springer   return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
21699ef9eebSMatthias Springer }
21799ef9eebSMatthias Springer 
operatesOnSuperVectorsOf(Operation & op,VectorType subVectorType)21899ef9eebSMatthias Springer bool matcher::operatesOnSuperVectorsOf(Operation &op,
21999ef9eebSMatthias Springer                                        VectorType subVectorType) {
22099ef9eebSMatthias Springer   // First, extract the vector type and distinguish between:
22199ef9eebSMatthias Springer   //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
22299ef9eebSMatthias Springer   //      vector.transfer_write); and
22399ef9eebSMatthias Springer   //   b. ops that *may* lower a super-vector (all other ops).
22499ef9eebSMatthias Springer   // The ops that *may* lower a super-vector only do so if the super-vector to
22599ef9eebSMatthias Springer   // sub-vector ratio exists. The ops that *must* lower a super-vector are
22699ef9eebSMatthias Springer   // explicitly checked for this property.
22799ef9eebSMatthias Springer   /// TODO: there should be a single function for all ops to do this so we
22899ef9eebSMatthias Springer   /// do not have to special case. Maybe a trait, or just a method, unclear atm.
22999ef9eebSMatthias Springer   bool mustDivide = false;
23099ef9eebSMatthias Springer   (void)mustDivide;
23199ef9eebSMatthias Springer   VectorType superVectorType;
23299ef9eebSMatthias Springer   if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
23399ef9eebSMatthias Springer     superVectorType = transfer.getVectorType();
23499ef9eebSMatthias Springer     mustDivide = true;
23599ef9eebSMatthias Springer   } else if (op.getNumResults() == 0) {
23623aa5a74SRiver Riddle     if (!isa<func::ReturnOp>(op)) {
23799ef9eebSMatthias Springer       op.emitError("NYI: assuming only return operations can have 0 "
23899ef9eebSMatthias Springer                    " results at this point");
23999ef9eebSMatthias Springer     }
24099ef9eebSMatthias Springer     return false;
24199ef9eebSMatthias Springer   } else if (op.getNumResults() == 1) {
24299ef9eebSMatthias Springer     if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) {
24399ef9eebSMatthias Springer       superVectorType = v;
24499ef9eebSMatthias Springer     } else {
24599ef9eebSMatthias Springer       // Not a vector type.
24699ef9eebSMatthias Springer       return false;
24799ef9eebSMatthias Springer     }
24899ef9eebSMatthias Springer   } else {
24999ef9eebSMatthias Springer     // Not a vector.transfer and has more than 1 result, fail hard for now to
25099ef9eebSMatthias Springer     // wake us up when something changes.
25199ef9eebSMatthias Springer     op.emitError("NYI: operation has more than 1 result");
25299ef9eebSMatthias Springer     return false;
25399ef9eebSMatthias Springer   }
25499ef9eebSMatthias Springer 
25599ef9eebSMatthias Springer   // Get the ratio.
25699ef9eebSMatthias Springer   auto ratio = shapeRatio(superVectorType, subVectorType);
25799ef9eebSMatthias Springer 
25899ef9eebSMatthias Springer   // Sanity check.
2595413bf1bSKazu Hirata   assert((ratio || !mustDivide) &&
26099ef9eebSMatthias Springer          "vector.transfer operation in which super-vector size is not an"
26199ef9eebSMatthias Springer          " integer multiple of sub-vector size");
26299ef9eebSMatthias Springer 
26399ef9eebSMatthias Springer   // This catches cases that are not strictly necessary to have multiplicity but
26499ef9eebSMatthias Springer   // still aren't divisible by the sub-vector shape.
26599ef9eebSMatthias Springer   // This could be useful information if we wanted to reshape at the level of
26699ef9eebSMatthias Springer   // the vector type (but we would have to look at the compute and distinguish
26799ef9eebSMatthias Springer   // between parallel, reduction and possibly other cases.
268*064a08cdSKazu Hirata   return ratio.has_value();
26999ef9eebSMatthias Springer }
270