1 //===------ Simplify.cpp ----------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Simplify a SCoP by removing unnecessary statements and accesses.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/Simplify.h"
15 #include "polly/ScopInfo.h"
16 #include "polly/ScopPass.h"
17 #include "polly/Support/GICHelper.h"
18 #include "polly/Support/ISLOStream.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/Support/Debug.h"
21 #define DEBUG_TYPE "polly-simplify"
22 
23 using namespace llvm;
24 using namespace polly;
25 
26 namespace {
27 
// LLVM statistic counters for this pass. Unlike the per-SCoP counters stored
// in the Simplify object (which are reset for every SCoP by releaseMemory()),
// these are never reset here and accumulate across all processed SCoPs.

// How many SCoPs were visited, and how many of them had at least one
// simplification applied.
STATISTIC(ScopsProcessed, "Number of SCoPs processed");
STATISTIC(ScopsModified, "Number of SCoPs simplified");

// Reasons why removeRedundantWrites() kept a load-store pair.
STATISTIC(PairUnequalAccRels, "Number of Load-Store pairs NOT removed because "
                              "of different access relations");
STATISTIC(InBetweenStore, "Number of Load-Store pairs NOT removed because "
                          "there is another store between them");
// Totals of the individual simplifications.
STATISTIC(TotalOverwritesRemoved, "Number of removed overwritten writes");
STATISTIC(TotalRedundantWritesRemoved,
          "Number of writes of same value removed in any SCoP");
STATISTIC(TotalStmtsRemoved, "Number of statements removed in any SCoP");
39 
40 static bool isImplicitRead(MemoryAccess *MA) {
41   return MA->isRead() && MA->isOriginalScalarKind();
42 }
43 
44 static bool isExplicitAccess(MemoryAccess *MA) {
45   return MA->isOriginalArrayKind();
46 }
47 
48 static bool isImplicitWrite(MemoryAccess *MA) {
49   return MA->isWrite() && MA->isOriginalScalarKind();
50 }
51 
52 /// Return a vector that contains MemoryAccesses in the order in
53 /// which they are executed.
54 ///
55 /// The order is:
56 /// - Implicit reads (BlockGenerator::generateScalarLoads)
57 /// - Explicit reads and writes (BlockGenerator::generateArrayLoad,
58 ///   BlockGenerator::generateArrayStore)
59 ///   - In block statements, the accesses are in order in which their
60 ///     instructions are executed.
61 ///   - In region statements, that order of execution is not predictable at
62 ///     compile-time.
63 /// - Implicit writes (BlockGenerator::generateScalarStores)
64 ///   The order in which implicit writes are executed relative to each other is
65 ///   undefined.
66 static SmallVector<MemoryAccess *, 32> getAccessesInOrder(ScopStmt &Stmt) {
67 
68   SmallVector<MemoryAccess *, 32> Accesses;
69 
70   for (MemoryAccess *MemAcc : Stmt)
71     if (isImplicitRead(MemAcc))
72       Accesses.push_back(MemAcc);
73 
74   for (MemoryAccess *MemAcc : Stmt)
75     if (isExplicitAccess(MemAcc))
76       Accesses.push_back(MemAcc);
77 
78   for (MemoryAccess *MemAcc : Stmt)
79     if (isImplicitWrite(MemAcc))
80       Accesses.push_back(MemAcc);
81 
82   return Accesses;
83 }
84 
85 class Simplify : public ScopPass {
86 private:
87   /// The last/current SCoP that is/has been processed.
88   Scop *S;
89 
90   /// Number of writes that are overwritten anyway.
91   int OverwritesRemoved = 0;
92 
93   /// Number of redundant writes removed from this SCoP.
94   int RedundantWritesRemoved = 0;
95 
96   /// Number of unnecessary statements removed from the SCoP.
97   int StmtsRemoved = 0;
98 
99   /// Return whether at least one simplification has been applied.
100   bool isModified() const {
101     return OverwritesRemoved > 0 || RedundantWritesRemoved > 0 ||
102            StmtsRemoved > 0;
103   }
104 
105   MemoryAccess *getReadAccessForValue(ScopStmt *Stmt, llvm::Value *Val) {
106     if (!isa<Instruction>(Val))
107       return nullptr;
108 
109     for (auto *MA : *Stmt) {
110       if (!MA->isRead())
111         continue;
112       if (MA->getAccessValue() != Val)
113         continue;
114 
115       return MA;
116     }
117 
118     return nullptr;
119   }
120 
121   /// Return a write access that occurs between @p From and @p To.
122   ///
123   /// In region statements the order is ignored because we cannot predict it.
124   ///
125   /// @param Stmt    Statement of both writes.
126   /// @param From    Start looking after this access.
127   /// @param To      Stop looking at this access, with the access itself.
128   /// @param Targets Look for an access that may wrote to one of these elements.
129   ///
130   /// @return A write access between @p From and @p To that writes to at least
131   ///         one element in @p Targets.
132   MemoryAccess *hasWriteBetween(ScopStmt *Stmt, MemoryAccess *From,
133                                 MemoryAccess *To, isl::map Targets) {
134     auto TargetsSpace = Targets.get_space();
135 
136     bool Started = Stmt->isRegionStmt();
137     auto Accesses = getAccessesInOrder(*Stmt);
138     for (auto *Acc : Accesses) {
139       if (Acc->isLatestScalarKind())
140         continue;
141 
142       if (Stmt->isBlockStmt() && From == Acc) {
143         assert(!Started);
144         Started = true;
145         continue;
146       }
147       if (Stmt->isBlockStmt() && To == Acc) {
148         assert(Started);
149         return nullptr;
150       }
151       if (!Started)
152         continue;
153 
154       if (!Acc->isWrite())
155         continue;
156 
157       auto AccRel = give(Acc->getAccessRelation());
158       auto AccRelSpace = AccRel.get_space();
159 
160       // Spaces being different means that they access different arrays.
161       if (!TargetsSpace.has_equal_tuples(AccRelSpace))
162         continue;
163 
164       AccRel = AccRel.intersect_domain(give(Acc->getStatement()->getDomain()));
165       AccRel = AccRel.intersect_params(give(S->getContext()));
166       auto CommonElt = Targets.intersect(AccRel);
167       if (!CommonElt.is_empty())
168         return Acc;
169     }
170     assert(Stmt->isRegionStmt() &&
171            "To must be encountered in block statements");
172     return nullptr;
173   }
174 
175   /// Remove writes that are overwritten unconditionally later in the same
176   /// statement.
177   ///
178   /// There must be no read of the same value between the write (that is to be
179   /// removed) and the overwrite.
180   void removeOverwrites() {
181     for (auto &Stmt : *S) {
182       auto Domain = give(Stmt.getDomain());
183       isl::union_map WillBeOverwritten =
184           isl::union_map::empty(give(S->getParamSpace()));
185 
186       SmallVector<MemoryAccess *, 32> Accesses(getAccessesInOrder(Stmt));
187 
188       // Iterate in reverse order, so the overwrite comes before the write that
189       // is to be removed.
190       for (auto *MA : reverse(Accesses)) {
191 
192         // In region statements, the explicit accesses can be in blocks that are
193         // can be executed in any order. We therefore process only the implicit
194         // writes and stop after that.
195         if (Stmt.isRegionStmt() && isExplicitAccess(MA))
196           break;
197 
198         auto AccRel = give(MA->getAccessRelation());
199         AccRel = AccRel.intersect_domain(Domain);
200         AccRel = AccRel.intersect_params(give(S->getContext()));
201 
202         // If a value is read in-between, do not consider it as overwritten.
203         if (MA->isRead()) {
204           WillBeOverwritten = WillBeOverwritten.subtract(AccRel);
205           continue;
206         }
207 
208         // If all of a write's elements are overwritten, remove it.
209         isl::union_map AccRelUnion = AccRel;
210         if (AccRelUnion.is_subset(WillBeOverwritten)) {
211           DEBUG(dbgs() << "Removing " << MA
212                        << " which will be overwritten anyway\n");
213 
214           Stmt.removeSingleMemoryAccess(MA);
215           OverwritesRemoved++;
216           TotalOverwritesRemoved++;
217         }
218 
219         // Unconditional writes overwrite other values.
220         if (MA->isMustWrite())
221           WillBeOverwritten = WillBeOverwritten.add_map(AccRel);
222       }
223     }
224   }
225 
226   /// Remove writes that just write the same value already stored in the
227   /// element.
228   void removeRedundantWrites() {
229     // Delay actual removal to not invalidate iterators.
230     SmallVector<MemoryAccess *, 8> StoresToRemove;
231 
232     for (auto &Stmt : *S) {
233       for (auto *WA : Stmt) {
234         if (!WA->isMustWrite())
235           continue;
236         if (!WA->isLatestArrayKind())
237           continue;
238         if (!isa<StoreInst>(WA->getAccessInstruction()) && !WA->isPHIKind())
239           continue;
240 
241         auto ReadingValue = WA->getAccessValue();
242 
243         if (WA->isPHIKind()) {
244           PHINode *PHI = cast<PHINode>(WA->getAccessValue());
245           BasicBlock *BB = Stmt.getBasicBlock();
246           ReadingValue = PHI->getIncomingValueForBlock(BB);
247         }
248 
249         if (!ReadingValue)
250           continue;
251 
252         auto RA = getReadAccessForValue(&Stmt, ReadingValue);
253         if (!RA)
254           continue;
255         if (!RA->isLatestArrayKind())
256           continue;
257 
258         auto WARel = give(WA->getLatestAccessRelation());
259         WARel = WARel.intersect_domain(give(WA->getStatement()->getDomain()));
260         WARel = WARel.intersect_params(give(S->getContext()));
261         auto RARel = give(RA->getLatestAccessRelation());
262         RARel = RARel.intersect_domain(give(RA->getStatement()->getDomain()));
263         RARel = RARel.intersect_params(give(S->getContext()));
264 
265         if (!RARel.is_equal(WARel)) {
266           PairUnequalAccRels++;
267           DEBUG(dbgs() << "Not cleaning up " << WA
268                        << " because of unequal access relations:\n");
269           DEBUG(dbgs() << "      RA: " << RARel << "\n");
270           DEBUG(dbgs() << "      WA: " << WARel << "\n");
271           continue;
272         }
273 
274         if (auto *Conflicting = hasWriteBetween(&Stmt, RA, WA, WARel)) {
275           (void)Conflicting;
276           InBetweenStore++;
277           DEBUG(dbgs() << "Not cleaning up " << WA
278                        << " because there is another store to the same element "
279                           "between\n");
280           DEBUG(Conflicting->print(dbgs()));
281           continue;
282         }
283 
284         StoresToRemove.push_back(WA);
285       }
286     }
287 
288     for (auto *WA : StoresToRemove) {
289       auto Stmt = WA->getStatement();
290       auto AccRel = give(WA->getAccessRelation());
291       auto AccVal = WA->getAccessValue();
292 
293       DEBUG(dbgs() << "Cleanup of " << WA << ":\n");
294       DEBUG(dbgs() << "      Scalar: " << *AccVal << "\n");
295       DEBUG(dbgs() << "      AccRel: " << AccRel << "\n");
296       (void)AccVal;
297       (void)AccRel;
298 
299       Stmt->removeSingleMemoryAccess(WA);
300 
301       RedundantWritesRemoved++;
302       TotalRedundantWritesRemoved++;
303     }
304   }
305 
306   /// Remove statements without side effects.
307   void removeUnnecessayStmts() {
308     auto NumStmtsBefore = S->getSize();
309     S->simplifySCoP(true);
310     assert(NumStmtsBefore >= S->getSize());
311     StmtsRemoved = NumStmtsBefore - S->getSize();
312     DEBUG(dbgs() << "Removed " << StmtsRemoved << " (of " << NumStmtsBefore
313                  << ") statements\n");
314     TotalStmtsRemoved += StmtsRemoved;
315   }
316 
317   /// Print simplification statistics to @p OS.
318   void printStatistics(llvm::raw_ostream &OS, int Indent = 0) const {
319     OS.indent(Indent) << "Statistics {\n";
320     OS.indent(Indent + 4) << "Overwrites removed: " << OverwritesRemoved
321                           << '\n';
322     OS.indent(Indent + 4) << "Redundant writes removed: "
323                           << RedundantWritesRemoved << "\n";
324     OS.indent(Indent + 4) << "Stmts removed: " << StmtsRemoved << "\n";
325     OS.indent(Indent) << "}\n";
326   }
327 
328   /// Print the current state of all MemoryAccesses to @p OS.
329   void printAccesses(llvm::raw_ostream &OS, int Indent = 0) const {
330     OS.indent(Indent) << "After accesses {\n";
331     for (auto &Stmt : *S) {
332       OS.indent(Indent + 4) << Stmt.getBaseName() << "\n";
333       for (auto *MA : Stmt)
334         MA->print(OS);
335     }
336     OS.indent(Indent) << "}\n";
337   }
338 
339 public:
340   static char ID;
341   explicit Simplify() : ScopPass(ID) {}
342 
343   virtual void getAnalysisUsage(AnalysisUsage &AU) const override {
344     AU.addRequiredTransitive<ScopInfoRegionPass>();
345     AU.setPreservesAll();
346   }
347 
348   virtual bool runOnScop(Scop &S) override {
349     // Reset statistics of last processed SCoP.
350     releaseMemory();
351 
352     // Prepare processing of this SCoP.
353     this->S = &S;
354     ScopsProcessed++;
355 
356     DEBUG(dbgs() << "Removing overwrites...\n");
357     removeOverwrites();
358 
359     DEBUG(dbgs() << "Removing redundant writes...\n");
360     removeRedundantWrites();
361 
362     DEBUG(dbgs() << "Removing statements without side effects...\n");
363     removeUnnecessayStmts();
364 
365     if (isModified())
366       ScopsModified++;
367     DEBUG(dbgs() << "\nFinal Scop:\n");
368     DEBUG(S.print(dbgs()));
369 
370     return false;
371   }
372 
373   virtual void printScop(raw_ostream &OS, Scop &S) const override {
374     assert(&S == this->S &&
375            "Can only print analysis for the last processed SCoP");
376     printStatistics(OS);
377 
378     if (!isModified()) {
379       OS << "SCoP could not be simplified\n";
380       return;
381     }
382     printAccesses(OS);
383   }
384 
385   virtual void releaseMemory() override {
386     S = nullptr;
387 
388     OverwritesRemoved = 0;
389     RedundantWritesRemoved = 0;
390     StmtsRemoved = 0;
391   }
392 };
393 
394 char Simplify::ID;
395 } // anonymous namespace
396 
397 Pass *polly::createSimplifyPass() { return new Simplify(); }
398 
399 INITIALIZE_PASS_BEGIN(Simplify, "polly-simplify", "Polly - Simplify", false,
400                       false)
401 INITIALIZE_PASS_END(Simplify, "polly-simplify", "Polly - Simplify", false,
402                     false)
403