//===- LoopAnalysis.cpp - Misc loop analysis routines ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop analysis routines.
//
//===----------------------------------------------------------------------===//
12755dc07dSRiver Riddle 
13755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
14755dc07dSRiver Riddle 
15755dc07dSRiver Riddle #include "mlir/Analysis/SliceAnalysis.h"
16755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
17755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
18755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
19755dc07dSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
20755dc07dSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
21755dc07dSRiver Riddle #include "mlir/Support/MathExtras.h"
22755dc07dSRiver Riddle 
23755dc07dSRiver Riddle #include "llvm/ADT/DenseSet.h"
24755dc07dSRiver Riddle #include "llvm/ADT/SmallPtrSet.h"
25755dc07dSRiver Riddle #include "llvm/ADT/SmallString.h"
26755dc07dSRiver Riddle #include <type_traits>
27755dc07dSRiver Riddle 
28755dc07dSRiver Riddle using namespace mlir;
29755dc07dSRiver Riddle 
30755dc07dSRiver Riddle /// Returns the trip count of the loop as an affine expression if the latter is
31755dc07dSRiver Riddle /// expressible as an affine expression, and nullptr otherwise. The trip count
32755dc07dSRiver Riddle /// expression is simplified before returning. This method only utilizes map
33755dc07dSRiver Riddle /// composition to construct lower and upper bounds before computing the trip
34755dc07dSRiver Riddle /// count expressions.
getTripCountMapAndOperands(AffineForOp forOp,AffineMap * tripCountMap,SmallVectorImpl<Value> * tripCountOperands)35755dc07dSRiver Riddle void mlir::getTripCountMapAndOperands(
36755dc07dSRiver Riddle     AffineForOp forOp, AffineMap *tripCountMap,
37755dc07dSRiver Riddle     SmallVectorImpl<Value> *tripCountOperands) {
38755dc07dSRiver Riddle   MLIRContext *context = forOp.getContext();
39755dc07dSRiver Riddle   int64_t step = forOp.getStep();
40755dc07dSRiver Riddle   int64_t loopSpan;
41755dc07dSRiver Riddle   if (forOp.hasConstantBounds()) {
42755dc07dSRiver Riddle     int64_t lb = forOp.getConstantLowerBound();
43755dc07dSRiver Riddle     int64_t ub = forOp.getConstantUpperBound();
44755dc07dSRiver Riddle     loopSpan = ub - lb;
45755dc07dSRiver Riddle     if (loopSpan < 0)
46755dc07dSRiver Riddle       loopSpan = 0;
47755dc07dSRiver Riddle     *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
48755dc07dSRiver Riddle     tripCountOperands->clear();
49755dc07dSRiver Riddle     return;
50755dc07dSRiver Riddle   }
51755dc07dSRiver Riddle   auto lbMap = forOp.getLowerBoundMap();
52755dc07dSRiver Riddle   auto ubMap = forOp.getUpperBoundMap();
53755dc07dSRiver Riddle   if (lbMap.getNumResults() != 1) {
54755dc07dSRiver Riddle     *tripCountMap = AffineMap();
55755dc07dSRiver Riddle     return;
56755dc07dSRiver Riddle   }
57755dc07dSRiver Riddle 
58755dc07dSRiver Riddle   // Difference of each upper bound expression from the single lower bound
59755dc07dSRiver Riddle   // expression (divided by the step) provides the expressions for the trip
60755dc07dSRiver Riddle   // count map.
61755dc07dSRiver Riddle   AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
62755dc07dSRiver Riddle 
63755dc07dSRiver Riddle   SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
64755dc07dSRiver Riddle                                          lbMap.getResult(0));
65755dc07dSRiver Riddle   auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
66755dc07dSRiver Riddle                                    lbSplatExpr, context);
67755dc07dSRiver Riddle   AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
68755dc07dSRiver Riddle 
69755dc07dSRiver Riddle   AffineValueMap tripCountValueMap;
70755dc07dSRiver Riddle   AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
71755dc07dSRiver Riddle   for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
72755dc07dSRiver Riddle     tripCountValueMap.setResult(i,
73755dc07dSRiver Riddle                                 tripCountValueMap.getResult(i).ceilDiv(step));
74755dc07dSRiver Riddle 
75755dc07dSRiver Riddle   *tripCountMap = tripCountValueMap.getAffineMap();
76755dc07dSRiver Riddle   tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
77755dc07dSRiver Riddle                             tripCountValueMap.getOperands().end());
78755dc07dSRiver Riddle }
79755dc07dSRiver Riddle 
80755dc07dSRiver Riddle /// Returns the trip count of the loop if it's a constant, None otherwise. This
81755dc07dSRiver Riddle /// method uses affine expression analysis (in turn using getTripCount) and is
82755dc07dSRiver Riddle /// able to determine constant trip count in non-trivial cases.
getConstantTripCount(AffineForOp forOp)83755dc07dSRiver Riddle Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) {
84755dc07dSRiver Riddle   SmallVector<Value, 4> operands;
85755dc07dSRiver Riddle   AffineMap map;
86755dc07dSRiver Riddle   getTripCountMapAndOperands(forOp, &map, &operands);
87755dc07dSRiver Riddle 
88755dc07dSRiver Riddle   if (!map)
89755dc07dSRiver Riddle     return None;
90755dc07dSRiver Riddle 
91755dc07dSRiver Riddle   // Take the min if all trip counts are constant.
92755dc07dSRiver Riddle   Optional<uint64_t> tripCount;
93755dc07dSRiver Riddle   for (auto resultExpr : map.getResults()) {
94755dc07dSRiver Riddle     if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
95491d2701SKazu Hirata       if (tripCount.has_value())
96*c27d8152SKazu Hirata         tripCount = std::min(tripCount.value(),
973b7c3a65SKazu Hirata                              static_cast<uint64_t>(constExpr.getValue()));
98755dc07dSRiver Riddle       else
99755dc07dSRiver Riddle         tripCount = constExpr.getValue();
100755dc07dSRiver Riddle     } else
101755dc07dSRiver Riddle       return None;
102755dc07dSRiver Riddle   }
103755dc07dSRiver Riddle   return tripCount;
104755dc07dSRiver Riddle }
105755dc07dSRiver Riddle 
106755dc07dSRiver Riddle /// Returns the greatest known integral divisor of the trip count. Affine
107755dc07dSRiver Riddle /// expression analysis is used (indirectly through getTripCount), and
108755dc07dSRiver Riddle /// this method is thus able to determine non-trivial divisors.
getLargestDivisorOfTripCount(AffineForOp forOp)109755dc07dSRiver Riddle uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) {
110755dc07dSRiver Riddle   SmallVector<Value, 4> operands;
111755dc07dSRiver Riddle   AffineMap map;
112755dc07dSRiver Riddle   getTripCountMapAndOperands(forOp, &map, &operands);
113755dc07dSRiver Riddle 
114755dc07dSRiver Riddle   if (!map)
115755dc07dSRiver Riddle     return 1;
116755dc07dSRiver Riddle 
117755dc07dSRiver Riddle   // The largest divisor of the trip count is the GCD of the individual largest
118755dc07dSRiver Riddle   // divisors.
119755dc07dSRiver Riddle   assert(map.getNumResults() >= 1 && "expected one or more results");
120755dc07dSRiver Riddle   Optional<uint64_t> gcd;
121755dc07dSRiver Riddle   for (auto resultExpr : map.getResults()) {
122755dc07dSRiver Riddle     uint64_t thisGcd;
123755dc07dSRiver Riddle     if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
124755dc07dSRiver Riddle       uint64_t tripCount = constExpr.getValue();
125755dc07dSRiver Riddle       // 0 iteration loops (greatest divisor is 2^64 - 1).
126755dc07dSRiver Riddle       if (tripCount == 0)
127755dc07dSRiver Riddle         thisGcd = std::numeric_limits<uint64_t>::max();
128755dc07dSRiver Riddle       else
129755dc07dSRiver Riddle         // The greatest divisor is the trip count.
130755dc07dSRiver Riddle         thisGcd = tripCount;
131755dc07dSRiver Riddle     } else {
132755dc07dSRiver Riddle       // Trip count is not a known constant; return its largest known divisor.
133755dc07dSRiver Riddle       thisGcd = resultExpr.getLargestKnownDivisor();
134755dc07dSRiver Riddle     }
135491d2701SKazu Hirata     if (gcd.has_value())
136*c27d8152SKazu Hirata       gcd = llvm::GreatestCommonDivisor64(gcd.value(), thisGcd);
137755dc07dSRiver Riddle     else
138755dc07dSRiver Riddle       gcd = thisGcd;
139755dc07dSRiver Riddle   }
140491d2701SKazu Hirata   assert(gcd.has_value() && "value expected per above logic");
141*c27d8152SKazu Hirata   return gcd.value();
142755dc07dSRiver Riddle }
143755dc07dSRiver Riddle 
144755dc07dSRiver Riddle /// Given an induction variable `iv` of type AffineForOp and an access `index`
145755dc07dSRiver Riddle /// of type index, returns `true` if `index` is independent of `iv` and
146755dc07dSRiver Riddle /// false otherwise. The determination supports composition with at most one
147755dc07dSRiver Riddle /// AffineApplyOp. The 'at most one AffineApplyOp' comes from the fact that
148755dc07dSRiver Riddle /// the composition of AffineApplyOp needs to be canonicalized by construction
149755dc07dSRiver Riddle /// to avoid writing code that composes arbitrary numbers of AffineApplyOps
150755dc07dSRiver Riddle /// everywhere. To achieve this, at the very least, the compose-affine-apply
151755dc07dSRiver Riddle /// pass must have been run.
152755dc07dSRiver Riddle ///
153755dc07dSRiver Riddle /// Prerequisites:
154755dc07dSRiver Riddle ///   1. `iv` and `index` of the proper type;
155755dc07dSRiver Riddle ///   2. at most one reachable AffineApplyOp from index;
156755dc07dSRiver Riddle ///
157755dc07dSRiver Riddle /// Returns false in cases with more than one AffineApplyOp, this is
158755dc07dSRiver Riddle /// conservative.
isAccessIndexInvariant(Value iv,Value index)159755dc07dSRiver Riddle static bool isAccessIndexInvariant(Value iv, Value index) {
160755dc07dSRiver Riddle   assert(isForInductionVar(iv) && "iv must be a AffineForOp");
161755dc07dSRiver Riddle   assert(index.getType().isa<IndexType>() && "index must be of IndexType");
162755dc07dSRiver Riddle   SmallVector<Operation *, 4> affineApplyOps;
163755dc07dSRiver Riddle   getReachableAffineApplyOps({index}, affineApplyOps);
164755dc07dSRiver Riddle 
165755dc07dSRiver Riddle   if (affineApplyOps.empty()) {
166755dc07dSRiver Riddle     // Pointer equality test because of Value pointer semantics.
167755dc07dSRiver Riddle     return index != iv;
168755dc07dSRiver Riddle   }
169755dc07dSRiver Riddle 
170755dc07dSRiver Riddle   if (affineApplyOps.size() > 1) {
171755dc07dSRiver Riddle     affineApplyOps[0]->emitRemark(
172755dc07dSRiver Riddle         "CompositionAffineMapsPass must have been run: there should be at most "
173755dc07dSRiver Riddle         "one AffineApplyOp, returning false conservatively.");
174755dc07dSRiver Riddle     return false;
175755dc07dSRiver Riddle   }
176755dc07dSRiver Riddle 
177755dc07dSRiver Riddle   auto composeOp = cast<AffineApplyOp>(affineApplyOps[0]);
178755dc07dSRiver Riddle   // We need yet another level of indirection because the `dim` index of the
179755dc07dSRiver Riddle   // access may not correspond to the `dim` index of composeOp.
180755dc07dSRiver Riddle   return !composeOp.getAffineValueMap().isFunctionOf(0, iv);
181755dc07dSRiver Riddle }
182755dc07dSRiver Riddle 
getInvariantAccesses(Value iv,ArrayRef<Value> indices)183755dc07dSRiver Riddle DenseSet<Value> mlir::getInvariantAccesses(Value iv, ArrayRef<Value> indices) {
184755dc07dSRiver Riddle   DenseSet<Value> res;
185755dc07dSRiver Riddle   for (auto val : indices) {
186755dc07dSRiver Riddle     if (isAccessIndexInvariant(iv, val)) {
187755dc07dSRiver Riddle       res.insert(val);
188755dc07dSRiver Riddle     }
189755dc07dSRiver Riddle   }
190755dc07dSRiver Riddle   return res;
191755dc07dSRiver Riddle }
192755dc07dSRiver Riddle 
193755dc07dSRiver Riddle /// Given:
194755dc07dSRiver Riddle ///   1. an induction variable `iv` of type AffineForOp;
195755dc07dSRiver Riddle ///   2. a `memoryOp` of type const LoadOp& or const StoreOp&;
196755dc07dSRiver Riddle /// determines whether `memoryOp` has a contiguous access along `iv`. Contiguous
197755dc07dSRiver Riddle /// is defined as either invariant or varying only along a unique MemRef dim.
198755dc07dSRiver Riddle /// Upon success, the unique MemRef dim is written in `memRefDim` (or -1 to
199755dc07dSRiver Riddle /// convey the memRef access is invariant along `iv`).
200755dc07dSRiver Riddle ///
201755dc07dSRiver Riddle /// Prerequisites:
202755dc07dSRiver Riddle ///   1. `memRefDim` ~= nullptr;
203755dc07dSRiver Riddle ///   2. `iv` of the proper type;
204755dc07dSRiver Riddle ///   3. the MemRef accessed by `memoryOp` has no layout map or at most an
205755dc07dSRiver Riddle ///      identity layout map.
206755dc07dSRiver Riddle ///
207755dc07dSRiver Riddle /// Currently only supports no layoutMap or identity layoutMap in the MemRef.
208755dc07dSRiver Riddle /// Returns false if the MemRef has a non-identity layoutMap or more than 1
209755dc07dSRiver Riddle /// layoutMap. This is conservative.
210755dc07dSRiver Riddle ///
211755dc07dSRiver Riddle // TODO: check strides.
212755dc07dSRiver Riddle template <typename LoadOrStoreOp>
isContiguousAccess(Value iv,LoadOrStoreOp memoryOp,int * memRefDim)213755dc07dSRiver Riddle static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
214755dc07dSRiver Riddle                                int *memRefDim) {
215755dc07dSRiver Riddle   static_assert(
216755dc07dSRiver Riddle       llvm::is_one_of<LoadOrStoreOp, AffineLoadOp, AffineStoreOp>::value,
217755dc07dSRiver Riddle       "Must be called on either LoadOp or StoreOp");
218755dc07dSRiver Riddle   assert(memRefDim && "memRefDim == nullptr");
219755dc07dSRiver Riddle   auto memRefType = memoryOp.getMemRefType();
220755dc07dSRiver Riddle 
221755dc07dSRiver Riddle   if (!memRefType.getLayout().isIdentity())
222755dc07dSRiver Riddle     return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
223755dc07dSRiver Riddle 
224755dc07dSRiver Riddle   int uniqueVaryingIndexAlongIv = -1;
225755dc07dSRiver Riddle   auto accessMap = memoryOp.getAffineMap();
226755dc07dSRiver Riddle   SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
227755dc07dSRiver Riddle   unsigned numDims = accessMap.getNumDims();
228755dc07dSRiver Riddle   for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
229755dc07dSRiver Riddle     // Gather map operands used result expr 'i' in 'exprOperands'.
230755dc07dSRiver Riddle     SmallVector<Value, 4> exprOperands;
231755dc07dSRiver Riddle     auto resultExpr = accessMap.getResult(i);
232755dc07dSRiver Riddle     resultExpr.walk([&](AffineExpr expr) {
233755dc07dSRiver Riddle       if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
234755dc07dSRiver Riddle         exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
235755dc07dSRiver Riddle       else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
236755dc07dSRiver Riddle         exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
237755dc07dSRiver Riddle     });
238755dc07dSRiver Riddle     // Check access invariance of each operand in 'exprOperands'.
239755dc07dSRiver Riddle     for (auto exprOperand : exprOperands) {
240755dc07dSRiver Riddle       if (!isAccessIndexInvariant(iv, exprOperand)) {
241755dc07dSRiver Riddle         if (uniqueVaryingIndexAlongIv != -1) {
242755dc07dSRiver Riddle           // 2+ varying indices -> do not vectorize along iv.
243755dc07dSRiver Riddle           return false;
244755dc07dSRiver Riddle         }
245755dc07dSRiver Riddle         uniqueVaryingIndexAlongIv = i;
246755dc07dSRiver Riddle       }
247755dc07dSRiver Riddle     }
248755dc07dSRiver Riddle   }
249755dc07dSRiver Riddle 
250755dc07dSRiver Riddle   if (uniqueVaryingIndexAlongIv == -1)
251755dc07dSRiver Riddle     *memRefDim = -1;
252755dc07dSRiver Riddle   else
253755dc07dSRiver Riddle     *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
254755dc07dSRiver Riddle   return true;
255755dc07dSRiver Riddle }
256755dc07dSRiver Riddle 
257755dc07dSRiver Riddle template <typename LoadOrStoreOp>
isVectorElement(LoadOrStoreOp memoryOp)258755dc07dSRiver Riddle static bool isVectorElement(LoadOrStoreOp memoryOp) {
259755dc07dSRiver Riddle   auto memRefType = memoryOp.getMemRefType();
260755dc07dSRiver Riddle   return memRefType.getElementType().template isa<VectorType>();
261755dc07dSRiver Riddle }
262755dc07dSRiver Riddle 
263755dc07dSRiver Riddle using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
264755dc07dSRiver Riddle 
265755dc07dSRiver Riddle static bool
isVectorizableLoopBodyWithOpCond(AffineForOp loop,const VectorizableOpFun & isVectorizableOp,NestedPattern & vectorTransferMatcher)266755dc07dSRiver Riddle isVectorizableLoopBodyWithOpCond(AffineForOp loop,
267755dc07dSRiver Riddle                                  const VectorizableOpFun &isVectorizableOp,
268755dc07dSRiver Riddle                                  NestedPattern &vectorTransferMatcher) {
269755dc07dSRiver Riddle   auto *forOp = loop.getOperation();
270755dc07dSRiver Riddle 
271755dc07dSRiver Riddle   // No vectorization across conditionals for now.
272755dc07dSRiver Riddle   auto conditionals = matcher::If();
273755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> conditionalsMatched;
274755dc07dSRiver Riddle   conditionals.match(forOp, &conditionalsMatched);
275755dc07dSRiver Riddle   if (!conditionalsMatched.empty()) {
276755dc07dSRiver Riddle     return false;
277755dc07dSRiver Riddle   }
278755dc07dSRiver Riddle 
279755dc07dSRiver Riddle   // No vectorization across unknown regions.
280755dc07dSRiver Riddle   auto regions = matcher::Op([](Operation &op) -> bool {
281755dc07dSRiver Riddle     return op.getNumRegions() != 0 && !isa<AffineIfOp, AffineForOp>(op);
282755dc07dSRiver Riddle   });
283755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> regionsMatched;
284755dc07dSRiver Riddle   regions.match(forOp, &regionsMatched);
285755dc07dSRiver Riddle   if (!regionsMatched.empty()) {
286755dc07dSRiver Riddle     return false;
287755dc07dSRiver Riddle   }
288755dc07dSRiver Riddle 
289755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> vectorTransfersMatched;
290755dc07dSRiver Riddle   vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
291755dc07dSRiver Riddle   if (!vectorTransfersMatched.empty()) {
292755dc07dSRiver Riddle     return false;
293755dc07dSRiver Riddle   }
294755dc07dSRiver Riddle 
295755dc07dSRiver Riddle   auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
296755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> loadAndStoresMatched;
297755dc07dSRiver Riddle   loadAndStores.match(forOp, &loadAndStoresMatched);
298755dc07dSRiver Riddle   for (auto ls : loadAndStoresMatched) {
299755dc07dSRiver Riddle     auto *op = ls.getMatchedOperation();
300755dc07dSRiver Riddle     auto load = dyn_cast<AffineLoadOp>(op);
301755dc07dSRiver Riddle     auto store = dyn_cast<AffineStoreOp>(op);
302755dc07dSRiver Riddle     // Only scalar types are considered vectorizable, all load/store must be
303755dc07dSRiver Riddle     // vectorizable for a loop to qualify as vectorizable.
304755dc07dSRiver Riddle     // TODO: ponder whether we want to be more general here.
305755dc07dSRiver Riddle     bool vector = load ? isVectorElement(load) : isVectorElement(store);
306755dc07dSRiver Riddle     if (vector) {
307755dc07dSRiver Riddle       return false;
308755dc07dSRiver Riddle     }
309755dc07dSRiver Riddle     if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
310755dc07dSRiver Riddle       return false;
311755dc07dSRiver Riddle     }
312755dc07dSRiver Riddle   }
313755dc07dSRiver Riddle   return true;
314755dc07dSRiver Riddle }
315755dc07dSRiver Riddle 
isVectorizableLoopBody(AffineForOp loop,int * memRefDim,NestedPattern & vectorTransferMatcher)316755dc07dSRiver Riddle bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
317755dc07dSRiver Riddle                                   NestedPattern &vectorTransferMatcher) {
318755dc07dSRiver Riddle   *memRefDim = -1;
319755dc07dSRiver Riddle   VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
320755dc07dSRiver Riddle     auto load = dyn_cast<AffineLoadOp>(op);
321755dc07dSRiver Riddle     auto store = dyn_cast<AffineStoreOp>(op);
322755dc07dSRiver Riddle     int thisOpMemRefDim = -1;
323755dc07dSRiver Riddle     bool isContiguous = load ? isContiguousAccess(loop.getInductionVar(), load,
324755dc07dSRiver Riddle                                                   &thisOpMemRefDim)
325755dc07dSRiver Riddle                              : isContiguousAccess(loop.getInductionVar(), store,
326755dc07dSRiver Riddle                                                   &thisOpMemRefDim);
327755dc07dSRiver Riddle     if (thisOpMemRefDim != -1) {
328755dc07dSRiver Riddle       // If memory accesses vary across different dimensions then the loop is
329755dc07dSRiver Riddle       // not vectorizable.
330755dc07dSRiver Riddle       if (*memRefDim != -1 && *memRefDim != thisOpMemRefDim)
331755dc07dSRiver Riddle         return false;
332755dc07dSRiver Riddle       *memRefDim = thisOpMemRefDim;
333755dc07dSRiver Riddle     }
334755dc07dSRiver Riddle     return isContiguous;
335755dc07dSRiver Riddle   });
336755dc07dSRiver Riddle   return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
337755dc07dSRiver Riddle }
338755dc07dSRiver Riddle 
isVectorizableLoopBody(AffineForOp loop,NestedPattern & vectorTransferMatcher)339755dc07dSRiver Riddle bool mlir::isVectorizableLoopBody(AffineForOp loop,
340755dc07dSRiver Riddle                                   NestedPattern &vectorTransferMatcher) {
341755dc07dSRiver Riddle   return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher);
342755dc07dSRiver Riddle }
343755dc07dSRiver Riddle 
344755dc07dSRiver Riddle /// Checks whether SSA dominance would be violated if a for op's body
345755dc07dSRiver Riddle /// operations are shifted by the specified shifts. This method checks if a
346755dc07dSRiver Riddle /// 'def' and all its uses have the same shift factor.
347755dc07dSRiver Riddle // TODO: extend this to check for memory-based dependence violation when we have
348755dc07dSRiver Riddle // the support.
isOpwiseShiftValid(AffineForOp forOp,ArrayRef<uint64_t> shifts)349755dc07dSRiver Riddle bool mlir::isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts) {
350755dc07dSRiver Riddle   auto *forBody = forOp.getBody();
351755dc07dSRiver Riddle   assert(shifts.size() == forBody->getOperations().size());
352755dc07dSRiver Riddle 
353755dc07dSRiver Riddle   // Work backwards over the body of the block so that the shift of a use's
354755dc07dSRiver Riddle   // ancestor operation in the block gets recorded before it's looked up.
355755dc07dSRiver Riddle   DenseMap<Operation *, uint64_t> forBodyShift;
356755dc07dSRiver Riddle   for (const auto &it :
357755dc07dSRiver Riddle        llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
358755dc07dSRiver Riddle     auto &op = it.value();
359755dc07dSRiver Riddle 
360755dc07dSRiver Riddle     // Get the index of the current operation, note that we are iterating in
361755dc07dSRiver Riddle     // reverse so we need to fix it up.
362755dc07dSRiver Riddle     size_t index = shifts.size() - it.index() - 1;
363755dc07dSRiver Riddle 
364755dc07dSRiver Riddle     // Remember the shift of this operation.
365755dc07dSRiver Riddle     uint64_t shift = shifts[index];
366755dc07dSRiver Riddle     forBodyShift.try_emplace(&op, shift);
367755dc07dSRiver Riddle 
368755dc07dSRiver Riddle     // Validate the results of this operation if it were to be shifted.
369755dc07dSRiver Riddle     for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
370755dc07dSRiver Riddle       Value result = op.getResult(i);
371755dc07dSRiver Riddle       for (auto *user : result.getUsers()) {
372755dc07dSRiver Riddle         // If an ancestor operation doesn't lie in the block of forOp,
373755dc07dSRiver Riddle         // there is no shift to check.
374755dc07dSRiver Riddle         if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
375755dc07dSRiver Riddle           assert(forBodyShift.count(ancOp) > 0 && "ancestor expected in map");
376755dc07dSRiver Riddle           if (shift != forBodyShift[ancOp])
377755dc07dSRiver Riddle             return false;
378755dc07dSRiver Riddle         }
379755dc07dSRiver Riddle       }
380755dc07dSRiver Riddle     }
381755dc07dSRiver Riddle   }
382755dc07dSRiver Riddle   return true;
383755dc07dSRiver Riddle }
384