//===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for affine structures
// (expressions, maps, sets), and other utilities relying on such analysis.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "affine-analysis"

using namespace mlir;
using namespace presburger;

/// Get the value that is being reduced by the `pos`-th reduction in the loop
/// if such a reduction can be performed by affine parallel loops. This assumes
/// floating-point operations are commutative. On success, `kind` will be the
/// reduction kind suitable for use in the affine parallel loop builder. If the
/// reduction is not supported, returns null.
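///
/// For example, in the illustrative loop below (a sketch, not code from this
/// file), the loaded value `%v` would be returned for `pos` == 0, with `kind`
/// set to `arith::AtomicRMWKind::addf`:
///
///   affine.for %i = 0 to 128 iter_args(%acc = %cst) -> (f32) {
///     %v = affine.load %buf[%i] : memref<128xf32>
///     %sum = arith.addf %acc, %v : f32
///     affine.yield %sum : f32
///   }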
static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
                                   arith::AtomicRMWKind &kind) {
  SmallVector<Operation *> combinerOps;
  Value reducedVal =
      matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
  if (!reducedVal)
    return nullptr;

  // Expected only one combiner operation.
  if (combinerOps.size() > 1)
    return nullptr;

  Operation *combinerOp = combinerOps.back();
  Optional<arith::AtomicRMWKind> maybeKind =
      TypeSwitch<Operation *, Optional<arith::AtomicRMWKind>>(combinerOp)
          .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
          .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
          .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
          .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
          .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
          .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
          .Case([](arith::MinFOp) { return arith::AtomicRMWKind::minf; })
          .Case([](arith::MaxFOp) { return arith::AtomicRMWKind::maxf; })
          .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
          .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
          .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
          .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
          .Default([](Operation *) -> Optional<arith::AtomicRMWKind> {
            // TODO: AtomicRMW supports other kinds of reductions that are not
            // detected here; add them when the need arises.
            return llvm::None;
          });
  if (!maybeKind)
    return nullptr;

  kind = *maybeKind;
  return reducedVal;
}

/// Populate `supportedReductions` with descriptors of the supported reductions.
void mlir::getSupportedReductions(
    AffineForOp forOp, SmallVectorImpl<LoopReduction> &supportedReductions) {
  unsigned numIterArgs = forOp.getNumIterOperands();
  if (numIterArgs == 0)
    return;
  supportedReductions.reserve(numIterArgs);
  for (unsigned i = 0; i < numIterArgs; ++i) {
    arith::AtomicRMWKind kind;
    if (Value value = getSupportedReduction(forOp, i, kind))
      supportedReductions.emplace_back(LoopReduction{kind, i, value});
  }
}

/// Returns true if `forOp` is a parallel loop. If `parallelReductions` is
/// provided, populates it with descriptors of the parallelizable reductions
/// and treats them as not preventing parallelization.
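///
/// A typical use looks as follows (an illustrative sketch):
///
///   SmallVector<LoopReduction> reductions;
///   if (isLoopParallel(forOp, &reductions)) {
///     // `forOp` is parallelizable; `reductions` describes how each of its
///     // iteration arguments combines values across iterations.
///   }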
bool mlir::isLoopParallel(AffineForOp forOp,
                          SmallVectorImpl<LoopReduction> *parallelReductions) {
  unsigned numIterArgs = forOp.getNumIterOperands();

  // Loop is not parallel if it has SSA loop-carried dependences and reduction
  // detection is not requested.
  if (numIterArgs > 0 && !parallelReductions)
    return false;

  // Find supported reductions if requested.
  if (parallelReductions) {
    getSupportedReductions(forOp, *parallelReductions);
    // Return only now so that all parallel reductions are identified even if
    // the loop turns out not to be parallel.
    if (parallelReductions->size() != numIterArgs)
      return false;
  }

  // Check memory dependences.
  return isLoopMemoryParallel(forOp);
}

/// Returns true if `v` is allocated locally to `enclosingOp` -- i.e., it is
/// allocated by an operation nested within `enclosingOp`.
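///
/// For example, in the illustrative snippet below, `%buf` is locally defined
/// with respect to the loop, and so would be any view of it (view-like ops
/// are chased back to their source):
///
///   affine.for %i = 0 to 10 {
///     %buf = memref.alloc() : memref<16xf32>
///     ...
///   }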
static bool isLocallyDefined(Value v, Operation *enclosingOp) {
  Operation *defOp = v.getDefiningOp();
  if (!defOp)
    return false;

  if (hasSingleEffect<MemoryEffects::Allocate>(defOp, v) &&
      enclosingOp->isProperAncestor(defOp))
    return true;

  // Aliasing ops.
  auto viewOp = dyn_cast<ViewLikeOpInterface>(defOp);
  return viewOp && isLocallyDefined(viewOp.getViewSource(), enclosingOp);
}

bool mlir::isLoopMemoryParallel(AffineForOp forOp) {
  // Any memref-typed iteration arguments are treated as serializing.
  if (llvm::any_of(forOp.getResultTypes(),
                   [](Type type) { return type.isa<BaseMemRefType>(); }))
    return false;

  // Collect all load and store ops in the loop nest rooted at 'forOp'.
  SmallVector<Operation *, 8> loadAndStoreOps;
  auto walkResult = forOp.walk([&](Operation *op) -> WalkResult {
    if (auto readOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Memrefs that are allocated inside `forOp` need not be considered.
      if (!isLocallyDefined(readOp.getMemRef(), forOp))
        loadAndStoreOps.push_back(op);
    } else if (auto writeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Filter out stores the same way as above.
      if (!isLocallyDefined(writeOp.getMemRef(), forOp))
        loadAndStoreOps.push_back(op);
    } else if (!isa<AffineForOp, AffineYieldOp, AffineIfOp>(op) &&
               !hasSingleEffect<MemoryEffects::Allocate>(op) &&
               !MemoryEffectOpInterface::hasNoEffect(op)) {
      // Alloc-like ops inside `forOp` are fine (they don't impact parallelism)
      // as long as they don't escape the loop (which has been checked above).
      return WalkResult::interrupt();
    }

    return WalkResult::advance();
  });

  // Stop early if the loop has unknown ops with side effects.
  if (walkResult.wasInterrupted())
    return false;

  // The dependence check depth is the number of enclosing loops + 1.
  unsigned depth = getNestingDepth(forOp) + 1;

  // Check dependences between all pairs of ops in 'loadAndStoreOps'.
  for (auto *srcOp : loadAndStoreOps) {
    MemRefAccess srcAccess(srcOp);
    for (auto *dstOp : loadAndStoreOps) {
      MemRefAccess dstAccess(dstOp);
      FlatAffineValueConstraints dependenceConstraints;
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, depth, &dependenceConstraints,
          /*dependenceComponents=*/nullptr);
      if (result.value != DependenceResult::NoDependence)
        return false;
    }
  }
  return true;
}

/// Returns in 'affineApplyOps' the sequence of AffineApplyOp operations that
/// are reachable via a search starting from 'operands' and ending at operands
/// that are not defined by an AffineApplyOp.
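///
/// For example, in the illustrative snippet below, a search starting from the
/// operand %idx visits both affine.apply ops and stops at the loop IV %i:
///
///   %a = affine.apply affine_map<(d0) -> (d0 floordiv 8)>(%i)
///   %idx = affine.apply affine_map<(d0) -> (d0 mod 4)>(%a)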
// TODO: Add a method to AffineApplyOp which forward substitutes the
// AffineApplyOp into any user AffineApplyOps.
void mlir::getReachableAffineApplyOps(
    ArrayRef<Value> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
  struct State {
    // The SSA value for this node in the DFS traversal.
    Value value;
    // The operand index of 'value' to explore next during the DFS traversal.
    unsigned operandIndex;
  };
  SmallVector<State, 4> worklist;
  for (auto operand : operands) {
    worklist.push_back({operand, 0});
  }

  while (!worklist.empty()) {
    State &state = worklist.back();
    auto *opInst = state.value.getDefiningOp();
    // Note: getDefiningOp will return nullptr if the operand is not an
    // Operation (i.e., a block argument), which terminates the search.
    if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
      worklist.pop_back();
      continue;
    }

    if (state.operandIndex == 0) {
      // Pre-visit: add 'opInst' to the reachable sequence.
      affineApplyOps.push_back(opInst);
    }
    if (state.operandIndex < opInst->getNumOperands()) {
      // Visit: add the operand of 'opInst' at 'operandIndex' to the worklist.
      auto nextOperand = opInst->getOperand(state.operandIndex);
      ++state.operandIndex;
      worklist.push_back({nextOperand, 0});
    } else {
      // Post-visit: done visiting the AffineApplyOp's operands; pop it off
      // the stack.
      worklist.pop_back();
    }
  }
}

// Builds a system of constraints with dimensional variables corresponding to
// the loop IVs of the forOps appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns failure for
// the yet unimplemented cases.
// TODO: Handle non-unit steps through local variables or stride information in
// FlatAffineValueConstraints. (E.g., by using iv - lb % step = 0 and/or by
// introducing a method in FlatAffineValueConstraints:
// setExprStride(ArrayRef<int64_t> expr, int64_t stride).)
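//
// For example, for the illustrative nest below, the resulting system would
// have two dimensional variables d0 and d1 for %i and %j, with the
// constraints 0 <= d0 < 100 and 0 <= d1 <= d0:
//
//   affine.for %i = 0 to 100 {
//     affine.for %j = 0 to affine_map<(d0) -> (d0 + 1)>(%i) {
//       ...
//     }
//   }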
LogicalResult mlir::getIndexSet(MutableArrayRef<Operation *> ops,
                                FlatAffineValueConstraints *domain) {
  SmallVector<Value, 4> indices;
  SmallVector<AffineForOp, 8> forOps;

  for (Operation *op : ops) {
    assert((isa<AffineForOp, AffineIfOp>(op)) &&
           "ops should be AffineForOp or AffineIfOp");
    if (AffineForOp forOp = dyn_cast<AffineForOp>(op))
      forOps.push_back(forOp);
  }
  extractForInductionVars(forOps, &indices);
  // Reset the domain while associating the Values in 'indices' with its
  // dimensions.
  domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
  for (Operation *op : ops) {
    // Add constraints from forOp's bounds.
    if (AffineForOp forOp = dyn_cast<AffineForOp>(op)) {
      if (failed(domain->addAffineForOpDomain(forOp)))
        return failure();
    } else if (AffineIfOp ifOp = dyn_cast<AffineIfOp>(op)) {
      domain->addAffineIfOpDomain(ifOp);
    }
  }
  return success();
}

/// Computes the iteration domain for 'op' and populates 'indexSet', which
/// encapsulates the constraints involving loops surrounding 'op' and
/// potentially involving any Function symbols. The dimensional variables in
/// 'indexSet' correspond to the loops surrounding 'op' from outermost to
/// innermost.
static LogicalResult getOpIndexSet(Operation *op,
                                   FlatAffineValueConstraints *indexSet) {
  SmallVector<Operation *, 4> ops;
  getEnclosingAffineForAndIfOps(*op, &ops);
  return getIndexSet(ops, indexSet);
}

// Returns the number of outer loops common to 'srcDomain' and 'dstDomain'.
// Loops common to the 'src' and 'dst' domains are added to 'commonLoops' if
// non-null.
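//
// For example, in the illustrative nest below, accesses in the two inner
// loops share only the outer loop %i, so the result would be 1:
//
//   affine.for %i = 0 to 10 {
//     affine.for %j = 0 to 10 { ... }  // contains the src access
//     affine.for %k = 0 to 10 { ... }  // contains the dst access
//   }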
static unsigned
getNumCommonLoops(const FlatAffineValueConstraints &srcDomain,
                  const FlatAffineValueConstraints &dstDomain,
                  SmallVectorImpl<AffineForOp> *commonLoops = nullptr) {
  // Find the number of common loops shared by src and dst accesses.
  unsigned minNumLoops =
      std::min(srcDomain.getNumDimVars(), dstDomain.getNumDimVars());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (!isForInductionVar(srcDomain.getValue(i)) ||
        !isForInductionVar(dstDomain.getValue(i)) ||
        srcDomain.getValue(i) != dstDomain.getValue(i))
      break;
    if (commonLoops != nullptr)
      commonLoops->push_back(getForInductionVarOwner(srcDomain.getValue(i)));
    ++numCommonLoops;
  }
  if (commonLoops != nullptr)
    assert(commonLoops->size() == numCommonLoops);
  return numCommonLoops;
}

/// Returns the Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
static Block *getCommonBlock(const MemRefAccess &srcAccess,
                             const MemRefAccess &dstAccess,
                             const FlatAffineValueConstraints &srcDomain,
                             unsigned numCommonLoops) {
  // Get the chain of ancestor blocks for the given `MemRefAccess` instance.
  // The search terminates when either an op with the `AffineScope` trait or
  // `endBlock` is reached.
  auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
                                      SmallVector<Block *, 4> &ancestorBlocks,
                                      Block *endBlock = nullptr) {
    Block *currBlock = access.opInst->getBlock();
    // The loop terminates when currBlock is nullptr, equals endBlock, or its
    // parent operation holds an affine scope.
    while (currBlock && currBlock != endBlock &&
           !currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
      ancestorBlocks.push_back(currBlock);
      currBlock = currBlock->getParentOp()->getBlock();
    }
  };

  if (numCommonLoops == 0) {
    Block *block = srcAccess.opInst->getBlock();
    while (!llvm::isa<func::FuncOp>(block->getParentOp())) {
      block = block->getParentOp()->getBlock();
    }
    return block;
  }
  Value commonForIV = srcDomain.getValue(numCommonLoops - 1);
  AffineForOp forOp = getForInductionVarOwner(commonForIV);
  assert(forOp && "commonForIV was not an induction variable");

  // Find the closest common block, including those in AffineIf ops.
  SmallVector<Block *, 4> srcAncestorBlocks, dstAncestorBlocks;
  getChainOfAncestorBlocks(srcAccess, srcAncestorBlocks, forOp.getBody());
  getChainOfAncestorBlocks(dstAccess, dstAncestorBlocks, forOp.getBody());

  Block *commonBlock = forOp.getBody();
  for (int i = srcAncestorBlocks.size() - 1, j = dstAncestorBlocks.size() - 1;
       i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
       i--, j--)
    commonBlock = srcAncestorBlocks[i];

  return commonBlock;
}

// Returns true if the ancestor operation of 'srcAccess' appears before the
// ancestor operation of 'dstAccess' in their common ancestral block; returns
// false otherwise.
// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
// the comparison is between their respective ancestor operations in the
// common block -- hence the name 'srcAppearsBeforeDstInAncestralBlock'. Note
// that 'numCommonLoops' is the number of contiguous surrounding outer loops.
static bool srcAppearsBeforeDstInAncestralBlock(
    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
    const FlatAffineValueConstraints &srcDomain, unsigned numCommonLoops) {
  // Get the Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
  auto *commonBlock =
      getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
  // Check the dominance relationship between the respective ancestors of the
  // src and dst in the Block of the innermost among the common loops.
  auto *srcInst = commonBlock->findAncestorOpInBlock(*srcAccess.opInst);
  assert(srcInst != nullptr);
  auto *dstInst = commonBlock->findAncestorOpInBlock(*dstAccess.opInst);
  assert(dstInst != nullptr);

  // Determine whether dstInst comes after srcInst.
  return srcInst->isBeforeInBlock(dstInst);
}

// Adds ordering constraints to 'dependenceDomain' based on the number of loops
// common to 'src/dstDomain' and the requested 'loopDepth'.
// Note that 'loopDepth' cannot exceed the number of common loops plus one.
// EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
// *) If 'loopDepth == 2' then two constraints are added: i == i' and
//    j' >= j + 1
// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
static void
addOrderingConstraints(const FlatAffineValueConstraints &srcDomain,
                       const FlatAffineValueConstraints &dstDomain,
                       unsigned loopDepth,
                       FlatAffineValueConstraints *dependenceDomain) {
  unsigned numCols = dependenceDomain->getNumCols();
  SmallVector<int64_t, 4> eq(numCols);
  unsigned numSrcDims = srcDomain.getNumDimVars();
  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
  unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
  for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
    std::fill(eq.begin(), eq.end(), 0);
    eq[i] = -1;
    eq[i + numSrcDims] = 1;
    if (i == loopDepth - 1) {
      eq[numCols - 1] = -1;
      dependenceDomain->addInequality(eq);
    } else {
      dependenceDomain->addEquality(eq);
    }
  }
}

// Computes distance and direction vectors in 'dependenceComponents', by adding
// variables to 'dependenceDomain' which represent the difference of the IVs,
// eliminating all other variables, and reading off distance vectors from
// equality constraints (where possible) and direction vectors from
// inequalities.
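//
// For example (an illustrative sketch), for a store to A[i] and a load from
// A[i - 2] nested under a common loop on i, the component for i would have
// lb == ub == 2, i.e., a constant dependence distance of 2.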
static void computeDirectionVector(
    const FlatAffineValueConstraints &srcDomain,
    const FlatAffineValueConstraints &dstDomain, unsigned loopDepth,
    FlatAffineValueConstraints *dependenceDomain,
    SmallVector<DependenceComponent, 2> *dependenceComponents) {
  // Find the number of common loops shared by src and dst accesses.
  SmallVector<AffineForOp, 4> commonLoops;
  unsigned numCommonLoops =
      getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
  if (numCommonLoops == 0)
    return;
  // Compute direction vectors for the requested loop depth.
  unsigned numIdsToEliminate = dependenceDomain->getNumVars();
  // Add new variables to 'dependenceDomain' to represent the direction
  // constraints for each shared loop.
  dependenceDomain->insertDimVar(/*pos=*/0, /*num=*/numCommonLoops);

  // Add equality constraints for each common loop, setting the newly
  // introduced variable at column 'j' to the 'dst' IV minus the 'src' IV.
  SmallVector<int64_t, 4> eq;
  eq.resize(dependenceDomain->getNumCols());
  unsigned numSrcDims = srcDomain.getNumDimVars();
  // Constraint variables format:
  // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant]
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    std::fill(eq.begin(), eq.end(), 0);
    eq[j] = 1;
    eq[j + numCommonLoops] = 1;
    eq[j + numCommonLoops + numSrcDims] = -1;
    dependenceDomain->addEquality(eq);
  }

  // Eliminate all variables other than the direction variables just added.
  dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);

  // Scan each common loop variable column and set direction vectors based
  // on the eliminated constraint system.
  dependenceComponents->resize(numCommonLoops);
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    (*dependenceComponents)[j].op = commonLoops[j].getOperation();
    auto lbConst = dependenceDomain->getConstantBound(IntegerPolyhedron::LB, j);
    (*dependenceComponents)[j].lb =
        lbConst.value_or(std::numeric_limits<int64_t>::min());
    auto ubConst = dependenceDomain->getConstantBound(IntegerPolyhedron::UB, j);
    (*dependenceComponents)[j].ub =
        ubConst.value_or(std::numeric_limits<int64_t>::max());
  }
}

LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const {
  // Create a set corresponding to the domain of the access.
  FlatAffineValueConstraints domain;
  if (failed(getOpIndexSet(opInst, &domain)))
    return failure();

  // Get the access relation from the access map.
  AffineValueMap accessValueMap;
  getAccessMap(&accessValueMap);
  if (failed(getRelationFromMap(accessValueMap, rel)))
    return failure();

  FlatAffineRelation domainRel(rel.getNumDomainDims(), /*numRangeDims=*/0,
                               domain);

  // Merge and align the domain ids of `rel` with the ids of `domain`. Since
  // the domain of the access map is a subset of the domain of the access, the
  // domain ids of `rel` are guaranteed to be a subset of the ids of `domain`.
  for (unsigned i = 0, e = domain.getNumDimVars(); i < e; ++i) {
    unsigned loc;
    if (rel.findVar(domain.getValue(i), &loc)) {
      rel.swapVar(i, loc);
    } else {
      rel.insertDomainVar(i);
      rel.setValue(i, domain.getValue(i));
    }
  }

  // Append domain constraints to `rel`.
  domainRel.appendRangeVar(rel.getNumRangeDims());
  domainRel.mergeSymbolVars(rel);
  domainRel.mergeLocalVars(rel);
  rel.append(domainRel);

  return success();
}

// Populates 'accessMap' with the composition of the AffineApplyOps reachable
// from the indices of this MemRefAccess.
void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
  // Get the affine map from the AffineLoad/Store op.
  AffineMap map;
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(opInst))
    map = loadOp.getAffineMap();
  else
    map = cast<AffineWriteOpInterface>(opInst).getAffineMap();

  SmallVector<Value, 8> operands(indices.begin(), indices.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);
  canonicalizeMapAndOperands(&map, &operands);
  accessMap->reset(map, operands);
}

// Builds a flat affine constraint system to check if there exists a dependence
// between memref accesses 'srcAccess' and 'dstAccess'.
// Returns 'NoDependence' if the accesses can be definitively shown not to
// access the same element.
// Returns 'HasDependence' if the accesses do access the same element.
// Returns 'Failure' if an error or unsupported case was encountered.
// If a dependence exists, returns in 'dependenceComponents' a direction
// vector for the dependence, with a component for each loop IV in loops
// common to both accesses (see Dependence in AffineAnalysis.h for details).
//
// The memref access dependence check is comprised of the following steps:
// *) Build an access relation for each access. An access relation maps
//    elements of an iteration domain to the element(s) of an array domain
//    accessed by that iteration of the associated statement through some
//    array reference.
// *) Compute the dependence relation by composing the access relation of
//    `srcAccess` with the inverse of the access relation of `dstAccess`.
//    Doing this builds a relation between the iteration domain of `srcAccess`
//    and the iteration domain of `dstAccess` which access the same memory
//    location.
// *) Add ordering constraints for `srcAccess` to be accessed before
//    `dstAccess`.
//
// This method builds a constraint system with the following column format:
//
//  [src-dim-variables, dst-dim-variables, symbols, constant]
//
// For example, given the following MLIR code with "source" and "destination"
// accesses to the same memref '%m', and symbols %M, %N, %K:
//
//   affine.for %i0 = 0 to 100 {
//     affine.for %i1 = 0 to 50 {
//       %a0 = affine.apply
//         (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
//       // Source memref access.
//       store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
//     }
//   }
//
//   affine.for %i2 = 0 to 100 {
//     affine.for %i3 = 0 to 50 {
//       %a1 = affine.apply
//         (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M]
//       // Destination memref access.
//       %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
//     }
//   }
//
// The access relation for `srcAccess` would be the following:
//
//   [src_dim0, src_dim1, mem_dim0, mem_dim1,  %N,   %M,  const]
//       2        -4       -1         0         1     0     0     = 0
//       0         3        0        -1         0    -1     0     = 0
//       1         0        0         0         0     0     0    >= 0
//      -1         0        0         0         0     0     100  >= 0
//       0         1        0         0         0     0     0    >= 0
//       0        -1        0         0         0     0     50   >= 0
//
//  The access relation for `dstAccess` would be the following:
//
//   [dst_dim0, dst_dim1, mem_dim0, mem_dim1,  %M,   %K,  const]
//       7         9       -1         0        -1     0     0     = 0
//       0         11       0        -1         0    -1     0     = 0
//       1         0        0         0         0     0     0    >= 0
//      -1         0        0         0         0     0     100  >= 0
//       0         1        0         0         0     0     0    >= 0
//       0        -1        0         0         0     0     50   >= 0
//
//  The equalities in the above relations correspond to the access maps while
//  the inequalities correspond to the iteration domain constraints.
//
// The dependence relation formed:
//
//   [src_dim0, src_dim1, dst_dim0, dst_dim1,  %M,   %N,   %K,  const]
//       2        -4        -7        -9        1     1     0     0    = 0
//       0         3         0        -11      -1     0     1     0    = 0
//       1         0         0         0        0     0     0     0    >= 0
//      -1         0         0         0        0     0     0     100  >= 0
//       0         1         0         0        0     0     0     0    >= 0
//       0        -1         0         0        0     0     0     50   >= 0
//       0         0         1         0        0     0     0     0    >= 0
//       0         0        -1         0        0     0     0     100  >= 0
//       0         0         0         1        0     0     0     0    >= 0
//       0         0         0        -1        0     0     0     50   >= 0
//
// TODO: Support AffineExprs mod/floordiv/ceildiv.
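//
// A typical call looks as follows (an illustrative sketch):
//
//   MemRefAccess srcAccess(srcOp), dstAccess(dstOp);
//   FlatAffineValueConstraints constraints;
//   DependenceResult result = checkMemrefAccessDependence(
//       srcAccess, dstAccess, /*loopDepth=*/1, &constraints,
//       /*dependenceComponents=*/nullptr);
//   if (hasDependence(result)) { ... }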
DependenceResult mlir::checkMemrefAccessDependence(
    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
    unsigned loopDepth, FlatAffineValueConstraints *dependenceConstraints,
    SmallVector<DependenceComponent, 2> *dependenceComponents, bool allowRAR) {
  LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
                          << Twine(loopDepth) << " between:\n";);
  LLVM_DEBUG(srcAccess.opInst->dump(););
  LLVM_DEBUG(dstAccess.opInst->dump(););

  // Return 'NoDependence' if these accesses do not access the same memref.
  if (srcAccess.memref != dstAccess.memref)
    return DependenceResult::NoDependence;

  // Return 'NoDependence' if neither of these accesses is an
  // AffineWriteOpInterface.
  if (!allowRAR && !isa<AffineWriteOpInterface>(srcAccess.opInst) &&
      !isa<AffineWriteOpInterface>(dstAccess.opInst))
    return DependenceResult::NoDependence;

  // Create an access relation from each MemRefAccess.
  FlatAffineRelation srcRel, dstRel;
  if (failed(srcAccess.getAccessRelation(srcRel)))
    return DependenceResult::Failure;
  if (failed(dstAccess.getAccessRelation(dstRel)))
    return DependenceResult::Failure;

  FlatAffineValueConstraints srcDomain = srcRel.getDomainSet();
  FlatAffineValueConstraints dstDomain = dstRel.getDomainSet();

  // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
  // operation of 'srcAccess' does not properly dominate the ancestor
  // operation of 'dstAccess' in the same common operation block.
  // Note: this check is skipped if 'allowRAR' is true, because RAR deps can
  // exist irrespective of the lexicographic ordering b/w src and dst.
  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
  assert(loopDepth <= numCommonLoops + 1);
  if (!allowRAR && loopDepth > numCommonLoops &&
      !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
                                           numCommonLoops)) {
    return DependenceResult::NoDependence;
  }

  // Compute the dependence relation by composing `srcRel` with the inverse of
  // `dstRel`. Doing this builds a relation between the iteration domain of
  // `srcAccess` and the iteration domain of `dstAccess` which access the same
  // memory locations.
  dstRel.inverse();
  dstRel.compose(srcRel);
  *dependenceConstraints = dstRel;

  // Add 'src' happens before 'dst' ordering constraints.
  addOrderingConstraints(srcDomain, dstDomain, loopDepth,
                         dependenceConstraints);

  // Return 'NoDependence' if the solution space is empty: no dependence.
  if (dependenceConstraints->isEmpty())
    return DependenceResult::NoDependence;

  // Compute the dependence direction vector, if requested.
  if (dependenceComponents != nullptr)
    computeDirectionVector(srcDomain, dstDomain, loopDepth,
                           dependenceConstraints, dependenceComponents);

  LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
  LLVM_DEBUG(dependenceConstraints->dump());
  return DependenceResult::HasDependence;
}

/// Gathers dependence components for dependences between all ops in the loop
/// nest rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
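///
/// For example (an illustrative sketch), to gather components for all
/// dependences in a nest of depth 3 rooted at `forOp`:
///
///   std::vector<SmallVector<DependenceComponent, 2>> depComps;
///   getDependenceComponents(forOp, /*maxLoopDepth=*/3, &depComps);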
void mlir::getDependenceComponents(
    AffineForOp forOp, unsigned maxLoopDepth,
    std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec) {
  // Collect all load and store ops in the loop nest rooted at 'forOp'.
  SmallVector<Operation *, 8> loadAndStoreOps;
  forOp->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
  });

  unsigned numOps = loadAndStoreOps.size();
  for (unsigned d = 1; d <= maxLoopDepth; ++d) {
    for (unsigned i = 0; i < numOps; ++i) {
      auto *srcOp = loadAndStoreOps[i];
      MemRefAccess srcAccess(srcOp);
      for (unsigned j = 0; j < numOps; ++j) {
        auto *dstOp = loadAndStoreOps[j];
        MemRefAccess dstAccess(dstOp);

        FlatAffineValueConstraints dependenceConstraints;
        SmallVector<DependenceComponent, 2> depComps;
        // TODO: Explore whether it would be profitable to pre-compute and
        // store deps instead of repeatedly checking.
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
        if (hasDependence(result))
          depCompsVec->push_back(depComps);
      }
    }
  }
}