1 //===------ Simplify.cpp ----------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Simplify a SCoP by removing unnecessary statements and accesses.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/Simplify.h"
15 #include "polly/ScopInfo.h"
16 #include "polly/ScopPass.h"
17 #include "polly/Support/GICHelper.h"
18 #include "polly/Support/ISLOStream.h"
19 #include "polly/Support/VirtualInstruction.h"
20 #include "llvm/ADT/Statistic.h"
21 #include "llvm/Support/Debug.h"
22 #define DEBUG_TYPE "polly-simplify"
23 
24 using namespace llvm;
25 using namespace polly;
26 
27 namespace {
28 
29 STATISTIC(ScopsProcessed, "Number of SCoPs processed");
30 STATISTIC(ScopsModified, "Number of SCoPs simplified");
31 
32 STATISTIC(PairUnequalAccRels, "Number of Load-Store pairs NOT removed because "
33                               "of different access relations");
34 STATISTIC(InBetweenStore, "Number of Load-Store pairs NOT removed because "
35                           "there is another store between them");
36 STATISTIC(TotalOverwritesRemoved, "Number of removed overwritten writes");
37 STATISTIC(TotalRedundantWritesRemoved,
38           "Number of writes of same value removed in any SCoP");
39 STATISTIC(TotalEmptyPartialAccessesRemoved,
40           "Number of empty partial accesses removed");
41 STATISTIC(TotalDeadAccessesRemoved, "Number of dead accesses removed");
42 STATISTIC(TotalDeadInstructionsRemoved,
43           "Number of unused instructions removed");
44 STATISTIC(TotalStmtsRemoved, "Number of statements removed in any SCoP");
45 
46 static bool isImplicitRead(MemoryAccess *MA) {
47   return MA->isRead() && MA->isOriginalScalarKind();
48 }
49 
50 static bool isExplicitAccess(MemoryAccess *MA) {
51   return MA->isOriginalArrayKind();
52 }
53 
54 static bool isImplicitWrite(MemoryAccess *MA) {
55   return MA->isWrite() && MA->isOriginalScalarKind();
56 }
57 
58 /// Return a vector that contains MemoryAccesses in the order in
59 /// which they are executed.
60 ///
61 /// The order is:
62 /// - Implicit reads (BlockGenerator::generateScalarLoads)
63 /// - Explicit reads and writes (BlockGenerator::generateArrayLoad,
64 ///   BlockGenerator::generateArrayStore)
65 ///   - In block statements, the accesses are in order in which their
66 ///     instructions are executed.
67 ///   - In region statements, that order of execution is not predictable at
68 ///     compile-time.
69 /// - Implicit writes (BlockGenerator::generateScalarStores)
70 ///   The order in which implicit writes are executed relative to each other is
71 ///   undefined.
72 static SmallVector<MemoryAccess *, 32> getAccessesInOrder(ScopStmt &Stmt) {
73 
74   SmallVector<MemoryAccess *, 32> Accesses;
75 
76   for (MemoryAccess *MemAcc : Stmt)
77     if (isImplicitRead(MemAcc))
78       Accesses.push_back(MemAcc);
79 
80   for (MemoryAccess *MemAcc : Stmt)
81     if (isExplicitAccess(MemAcc))
82       Accesses.push_back(MemAcc);
83 
84   for (MemoryAccess *MemAcc : Stmt)
85     if (isImplicitWrite(MemAcc))
86       Accesses.push_back(MemAcc);
87 
88   return Accesses;
89 }
90 
91 class Simplify : public ScopPass {
92 private:
93   /// The last/current SCoP that is/has been processed.
94   Scop *S;
95 
96   /// Number of writes that are overwritten anyway.
97   int OverwritesRemoved = 0;
98 
99   /// Number of redundant writes removed from this SCoP.
100   int RedundantWritesRemoved = 0;
101 
102   /// Number of writes with empty access domain removed.
103   int EmptyPartialAccessesRemoved = 0;
104 
105   /// Number of unused accesses removed from this SCoP.
106   int DeadAccessesRemoved = 0;
107 
108   /// Number of unused instructions removed from this SCoP.
109   int DeadInstructionsRemoved = 0;
110 
111   /// Number of unnecessary statements removed from the SCoP.
112   int StmtsRemoved = 0;
113 
114   /// Return whether at least one simplification has been applied.
115   bool isModified() const {
116     return OverwritesRemoved > 0 || RedundantWritesRemoved > 0 ||
117            EmptyPartialAccessesRemoved > 0 || DeadAccessesRemoved > 0 ||
118            DeadInstructionsRemoved > 0 || StmtsRemoved > 0;
119   }
120 
121   MemoryAccess *getReadAccessForValue(ScopStmt *Stmt, llvm::Value *Val) {
122     if (!isa<Instruction>(Val))
123       return nullptr;
124 
125     for (auto *MA : *Stmt) {
126       if (!MA->isRead())
127         continue;
128       if (MA->getAccessValue() != Val)
129         continue;
130 
131       return MA;
132     }
133 
134     return nullptr;
135   }
136 
137   /// Return a write access that occurs between @p From and @p To.
138   ///
139   /// In region statements the order is ignored because we cannot predict it.
140   ///
141   /// @param Stmt    Statement of both writes.
142   /// @param From    Start looking after this access.
143   /// @param To      Stop looking at this access, with the access itself.
144   /// @param Targets Look for an access that may wrote to one of these elements.
145   ///
146   /// @return A write access between @p From and @p To that writes to at least
147   ///         one element in @p Targets.
148   MemoryAccess *hasWriteBetween(ScopStmt *Stmt, MemoryAccess *From,
149                                 MemoryAccess *To, isl::map Targets) {
150     auto TargetsSpace = Targets.get_space();
151 
152     bool Started = Stmt->isRegionStmt();
153     auto Accesses = getAccessesInOrder(*Stmt);
154     for (auto *Acc : Accesses) {
155       if (Acc->isLatestScalarKind())
156         continue;
157 
158       if (Stmt->isBlockStmt() && From == Acc) {
159         assert(!Started);
160         Started = true;
161         continue;
162       }
163       if (Stmt->isBlockStmt() && To == Acc) {
164         assert(Started);
165         return nullptr;
166       }
167       if (!Started)
168         continue;
169 
170       if (!Acc->isWrite())
171         continue;
172 
173       isl::map AccRel = Acc->getAccessRelation();
174       auto AccRelSpace = AccRel.get_space();
175 
176       // Spaces being different means that they access different arrays.
177       if (!TargetsSpace.has_equal_tuples(AccRelSpace))
178         continue;
179 
180       AccRel = AccRel.intersect_domain(give(Acc->getStatement()->getDomain()));
181       AccRel = AccRel.intersect_params(give(S->getContext()));
182       auto CommonElt = Targets.intersect(AccRel);
183       if (!CommonElt.is_empty())
184         return Acc;
185     }
186     assert(Stmt->isRegionStmt() &&
187            "To must be encountered in block statements");
188     return nullptr;
189   }
190 
191   /// Remove writes that are overwritten unconditionally later in the same
192   /// statement.
193   ///
194   /// There must be no read of the same value between the write (that is to be
195   /// removed) and the overwrite.
196   void removeOverwrites() {
197     for (auto &Stmt : *S) {
198       auto Domain = give(Stmt.getDomain());
199       isl::union_map WillBeOverwritten =
200           isl::union_map::empty(give(S->getParamSpace()));
201 
202       SmallVector<MemoryAccess *, 32> Accesses(getAccessesInOrder(Stmt));
203 
204       // Iterate in reverse order, so the overwrite comes before the write that
205       // is to be removed.
206       for (auto *MA : reverse(Accesses)) {
207 
208         // In region statements, the explicit accesses can be in blocks that are
209         // can be executed in any order. We therefore process only the implicit
210         // writes and stop after that.
211         if (Stmt.isRegionStmt() && isExplicitAccess(MA))
212           break;
213 
214         auto AccRel = MA->getAccessRelation();
215         AccRel = AccRel.intersect_domain(Domain);
216         AccRel = AccRel.intersect_params(give(S->getContext()));
217 
218         // If a value is read in-between, do not consider it as overwritten.
219         if (MA->isRead()) {
220           WillBeOverwritten = WillBeOverwritten.subtract(AccRel);
221           continue;
222         }
223 
224         // If all of a write's elements are overwritten, remove it.
225         isl::union_map AccRelUnion = AccRel;
226         if (AccRelUnion.is_subset(WillBeOverwritten)) {
227           DEBUG(dbgs() << "Removing " << MA
228                        << " which will be overwritten anyway\n");
229 
230           Stmt.removeSingleMemoryAccess(MA);
231           OverwritesRemoved++;
232           TotalOverwritesRemoved++;
233         }
234 
235         // Unconditional writes overwrite other values.
236         if (MA->isMustWrite())
237           WillBeOverwritten = WillBeOverwritten.add_map(AccRel);
238       }
239     }
240   }
241 
242   /// Remove writes that just write the same value already stored in the
243   /// element.
244   void removeRedundantWrites() {
245     // Delay actual removal to not invalidate iterators.
246     SmallVector<MemoryAccess *, 8> StoresToRemove;
247 
248     for (auto &Stmt : *S) {
249       for (auto *WA : Stmt) {
250         if (!WA->isMustWrite())
251           continue;
252         if (!WA->isLatestArrayKind())
253           continue;
254         if (!isa<StoreInst>(WA->getAccessInstruction()) && !WA->isPHIKind())
255           continue;
256 
257         llvm::Value *ReadingValue = WA->tryGetValueStored();
258 
259         if (!ReadingValue)
260           continue;
261 
262         auto RA = getReadAccessForValue(&Stmt, ReadingValue);
263         if (!RA)
264           continue;
265         if (!RA->isLatestArrayKind())
266           continue;
267 
268         auto WARel = WA->getLatestAccessRelation();
269         WARel = WARel.intersect_domain(give(WA->getStatement()->getDomain()));
270         WARel = WARel.intersect_params(give(S->getContext()));
271         auto RARel = RA->getLatestAccessRelation();
272         RARel = RARel.intersect_domain(give(RA->getStatement()->getDomain()));
273         RARel = RARel.intersect_params(give(S->getContext()));
274 
275         if (!RARel.is_equal(WARel)) {
276           PairUnequalAccRels++;
277           DEBUG(dbgs() << "Not cleaning up " << WA
278                        << " because of unequal access relations:\n");
279           DEBUG(dbgs() << "      RA: " << RARel << "\n");
280           DEBUG(dbgs() << "      WA: " << WARel << "\n");
281           continue;
282         }
283 
284         if (auto *Conflicting = hasWriteBetween(&Stmt, RA, WA, WARel)) {
285           (void)Conflicting;
286           InBetweenStore++;
287           DEBUG(dbgs() << "Not cleaning up " << WA
288                        << " because there is another store to the same element "
289                           "between\n");
290           DEBUG(Conflicting->print(dbgs()));
291           continue;
292         }
293 
294         StoresToRemove.push_back(WA);
295       }
296     }
297 
298     for (auto *WA : StoresToRemove) {
299       auto Stmt = WA->getStatement();
300       auto AccRel = WA->getAccessRelation();
301       auto AccVal = WA->getAccessValue();
302 
303       DEBUG(dbgs() << "Cleanup of " << WA << ":\n");
304       DEBUG(dbgs() << "      Scalar: " << *AccVal << "\n");
305       DEBUG(dbgs() << "      AccRel: " << AccRel << "\n");
306       (void)AccVal;
307       (void)AccRel;
308 
309       Stmt->removeSingleMemoryAccess(WA);
310 
311       RedundantWritesRemoved++;
312       TotalRedundantWritesRemoved++;
313     }
314   }
315 
316   /// Remove statements without side effects.
317   void removeUnnecessaryStmts() {
318     auto NumStmtsBefore = S->getSize();
319     S->simplifySCoP(true);
320     assert(NumStmtsBefore >= S->getSize());
321     StmtsRemoved = NumStmtsBefore - S->getSize();
322     DEBUG(dbgs() << "Removed " << StmtsRemoved << " (of " << NumStmtsBefore
323                  << ") statements\n");
324     TotalStmtsRemoved += StmtsRemoved;
325   }
326 
327   /// Remove accesses that have an empty domain.
328   void removeEmptyPartialAccesses() {
329     for (ScopStmt &Stmt : *S) {
330       // Defer the actual removal to not invalidate iterators.
331       SmallVector<MemoryAccess *, 8> DeferredRemove;
332 
333       for (MemoryAccess *MA : Stmt) {
334         if (!MA->isWrite())
335           continue;
336 
337         isl::map AccRel = MA->getAccessRelation();
338         if (!AccRel.is_empty().is_true())
339           continue;
340 
341         DEBUG(dbgs() << "Removing " << MA
342                      << " because it's a partial access that never occurs\n");
343         DeferredRemove.push_back(MA);
344       }
345 
346       for (MemoryAccess *MA : DeferredRemove) {
347         Stmt.removeSingleMemoryAccess(MA);
348         EmptyPartialAccessesRemoved++;
349         TotalEmptyPartialAccessesRemoved++;
350       }
351     }
352   }
353 
354   /// Mark all reachable instructions and access, and sweep those that are not
355   /// reachable.
356   void markAndSweep(LoopInfo *LI) {
357     DenseSet<MemoryAccess *> UsedMA;
358     DenseSet<VirtualInstruction> UsedInsts;
359 
360     // Get all reachable instructions and accesses.
361     markReachable(S, LI, UsedInsts, UsedMA);
362 
363     // Remove all non-reachable accesses.
364     // We need get all MemoryAccesses first, in order to not invalidate the
365     // iterators when removing them.
366     SmallVector<MemoryAccess *, 64> AllMAs;
367     for (ScopStmt &Stmt : *S)
368       AllMAs.append(Stmt.begin(), Stmt.end());
369 
370     for (MemoryAccess *MA : AllMAs) {
371       if (UsedMA.count(MA))
372         continue;
373       DEBUG(dbgs() << "Removing " << MA << " because its value is not used\n");
374       ScopStmt *Stmt = MA->getStatement();
375       Stmt->removeSingleMemoryAccess(MA);
376 
377       DeadAccessesRemoved++;
378       TotalDeadAccessesRemoved++;
379     }
380 
381     // Remove all non-reachable instructions.
382     for (ScopStmt &Stmt : *S) {
383       if (!Stmt.isBlockStmt())
384         continue;
385 
386       SmallVector<Instruction *, 32> AllInsts(Stmt.insts_begin(),
387                                               Stmt.insts_end());
388       SmallVector<Instruction *, 32> RemainInsts;
389 
390       for (Instruction *Inst : AllInsts) {
391         auto It = UsedInsts.find({&Stmt, Inst});
392         if (It == UsedInsts.end()) {
393           DEBUG(dbgs() << "Removing "; Inst->print(dbgs());
394                 dbgs() << " because it is not used\n");
395           DeadInstructionsRemoved++;
396           TotalDeadInstructionsRemoved++;
397           continue;
398         }
399 
400         RemainInsts.push_back(Inst);
401 
402         // If instructions appear multiple times, keep only the first.
403         UsedInsts.erase(It);
404       }
405 
406       // Set the new instruction list to be only those we did not remove.
407       Stmt.setInstructions(RemainInsts);
408     }
409   }
410 
411   /// Print simplification statistics to @p OS.
412   void printStatistics(llvm::raw_ostream &OS, int Indent = 0) const {
413     OS.indent(Indent) << "Statistics {\n";
414     OS.indent(Indent + 4) << "Overwrites removed: " << OverwritesRemoved
415                           << '\n';
416     OS.indent(Indent + 4) << "Redundant writes removed: "
417                           << RedundantWritesRemoved << "\n";
418     OS.indent(Indent + 4) << "Accesses with empty domains removed: "
419                           << EmptyPartialAccessesRemoved << "\n";
420     OS.indent(Indent + 4) << "Dead accesses removed: " << DeadAccessesRemoved
421                           << '\n';
422     OS.indent(Indent + 4) << "Dead instructions removed: "
423                           << DeadInstructionsRemoved << '\n';
424     OS.indent(Indent + 4) << "Stmts removed: " << StmtsRemoved << "\n";
425     OS.indent(Indent) << "}\n";
426   }
427 
428   /// Print the current state of all MemoryAccesses to @p OS.
429   void printAccesses(llvm::raw_ostream &OS, int Indent = 0) const {
430     OS.indent(Indent) << "After accesses {\n";
431     for (auto &Stmt : *S) {
432       OS.indent(Indent + 4) << Stmt.getBaseName() << "\n";
433       for (auto *MA : Stmt)
434         MA->print(OS);
435     }
436     OS.indent(Indent) << "}\n";
437   }
438 
439 public:
  /// Pass identification member; it is handed to the ScopPass base class by
  /// the constructor.
  static char ID;
  /// Create the pass. All per-SCoP counters start at zero through their
  /// in-class initializers.
  explicit Simplify() : ScopPass(ID) {}
442 
443   virtual void getAnalysisUsage(AnalysisUsage &AU) const override {
444     AU.addRequiredTransitive<ScopInfoRegionPass>();
445     AU.addRequired<LoopInfoWrapperPass>();
446     AU.setPreservesAll();
447   }
448 
449   virtual bool runOnScop(Scop &S) override {
450     // Reset statistics of last processed SCoP.
451     releaseMemory();
452     assert(!isModified());
453 
454     // Prepare processing of this SCoP.
455     this->S = &S;
456     ScopsProcessed++;
457 
458     DEBUG(dbgs() << "Removing partial writes that never happen...\n");
459     removeEmptyPartialAccesses();
460 
461     DEBUG(dbgs() << "Removing overwrites...\n");
462     removeOverwrites();
463 
464     DEBUG(dbgs() << "Removing redundant writes...\n");
465     removeRedundantWrites();
466 
467     DEBUG(dbgs() << "Cleanup unused accesses...\n");
468     LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
469     markAndSweep(LI);
470 
471     DEBUG(dbgs() << "Removing statements without side effects...\n");
472     removeUnnecessaryStmts();
473 
474     if (isModified())
475       ScopsModified++;
476     DEBUG(dbgs() << "\nFinal Scop:\n");
477     DEBUG(dbgs() << S);
478 
479     return false;
480   }
481 
482   virtual void printScop(raw_ostream &OS, Scop &S) const override {
483     assert(&S == this->S &&
484            "Can only print analysis for the last processed SCoP");
485     printStatistics(OS);
486 
487     if (!isModified()) {
488       OS << "SCoP could not be simplified\n";
489       return;
490     }
491     printAccesses(OS);
492   }
493 
494   virtual void releaseMemory() override {
495     S = nullptr;
496 
497     OverwritesRemoved = 0;
498     RedundantWritesRemoved = 0;
499     EmptyPartialAccessesRemoved = 0;
500     DeadAccessesRemoved = 0;
501     DeadInstructionsRemoved = 0;
502     StmtsRemoved = 0;
503   }
504 };
505 
// Out-of-class definition of the static member declared in Simplify.
char Simplify::ID;
507 } // anonymous namespace
508 
// Factory for the pass: returns a newly allocated Simplify instance.
// NOTE(review): the caller presumably takes ownership of the returned
// pointer -- confirm against the users of createSimplifyPass().
Pass *polly::createSimplifyPass() { return new Simplify(); }
510 
// Pass initialization boilerplate for Simplify, using the argument string
// "polly-simplify" and the name "Polly - Simplify". LoopInfoWrapperPass is
// the only dependency declared here.
INITIALIZE_PASS_BEGIN(Simplify, "polly-simplify", "Polly - Simplify", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(Simplify, "polly-simplify", "Polly - Simplify", false,
                    false)
516