1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements utilities for working with "normalized" expressions.
11 // See the comments at the top of ScalarEvolutionNormalization.h for details.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
18 using namespace llvm;
19 
/// The direction in which transformSubExpr rewrites an expression.
enum TransformKind {
  Normalize,  ///< Normalize with respect to the loops selected by the predicate.
  Denormalize ///< Inverse of Normalize for the same selection of loops.
};
29 
30 typedef DenseMap<const SCEV *, const SCEV *> NormalizedCacheTy;
31 
32 static const SCEV *transformSubExpr(const TransformKind Kind,
33                                     NormalizePredTy Pred, ScalarEvolution &SE,
34                                     NormalizedCacheTy &Cache, const SCEV *S);
35 
36 /// Implement post-inc transformation for all valid expression types.
37 static const SCEV *transformImpl(const TransformKind Kind, NormalizePredTy Pred,
38                                  ScalarEvolution &SE, NormalizedCacheTy &Cache,
39                                  const SCEV *S) {
40   if (const SCEVCastExpr *X = dyn_cast<SCEVCastExpr>(S)) {
41     const SCEV *O = X->getOperand();
42     const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O);
43     if (O != N)
44       switch (S->getSCEVType()) {
45       case scZeroExtend: return SE.getZeroExtendExpr(N, S->getType());
46       case scSignExtend: return SE.getSignExtendExpr(N, S->getType());
47       case scTruncate: return SE.getTruncateExpr(N, S->getType());
48       default: llvm_unreachable("Unexpected SCEVCastExpr kind!");
49       }
50     return S;
51   }
52 
53   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
54     // An addrec. This is the interesting part.
55     SmallVector<const SCEV *, 8> Operands;
56 
57     transform(AR->operands(), std::back_inserter(Operands),
58               [&](const SCEV *Op) {
59                 return transformSubExpr(Kind, Pred, SE, Cache, Op);
60               });
61 
62     // Conservatively use AnyWrap until/unless we need FlagNW.
63     const SCEV *Result =
64         SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
65     switch (Kind) {
66     case Normalize:
67       // We want to normalize step expression, because otherwise we might not be
68       // able to denormalize to the original expression.
69       //
70       // Here is an example what will happen if we don't normalize step:
71       //  ORIGINAL ISE:
72       //    {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25>
73       //  NORMALIZED ISE:
74       //    {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+,
75       //     (100 /u {0,+,1}<%bb16>)}<%bb25>
76       //  DENORMALIZED BACK ISE:
77       //    {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+,
78       //     (100 /u {1,+,1}<%bb16>)}<%bb25>
79       //  Note that the initial value changes after normalization +
80       //  denormalization, which isn't correct.
81       if (Pred(AR)) {
82         const SCEV *TransformedStep =
83             transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE));
84         Result = SE.getMinusSCEV(Result, TransformedStep);
85       }
86       break;
87     case Denormalize:
88       // Here we want to normalize step expressions for the same reasons, as
89       // stated above.
90       if (Pred(AR)) {
91         const SCEV *TransformedStep =
92             transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE));
93         Result = SE.getAddExpr(Result, TransformedStep);
94       }
95       break;
96     }
97     return Result;
98   }
99 
100   if (const SCEVNAryExpr *X = dyn_cast<SCEVNAryExpr>(S)) {
101     SmallVector<const SCEV *, 8> Operands;
102     bool Changed = false;
103     // Transform each operand.
104     for (auto *O : X->operands()) {
105       const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O);
106       Changed |= N != O;
107       Operands.push_back(N);
108     }
109     // If any operand actually changed, return a transformed result.
110     if (Changed)
111       switch (S->getSCEVType()) {
112       case scAddExpr: return SE.getAddExpr(Operands);
113       case scMulExpr: return SE.getMulExpr(Operands);
114       case scSMaxExpr: return SE.getSMaxExpr(Operands);
115       case scUMaxExpr: return SE.getUMaxExpr(Operands);
116       default: llvm_unreachable("Unexpected SCEVNAryExpr kind!");
117       }
118     return S;
119   }
120 
121   if (const SCEVUDivExpr *X = dyn_cast<SCEVUDivExpr>(S)) {
122     const SCEV *LO = X->getLHS();
123     const SCEV *RO = X->getRHS();
124     const SCEV *LN = transformSubExpr(Kind, Pred, SE, Cache, LO);
125     const SCEV *RN = transformSubExpr(Kind, Pred, SE, Cache, RO);
126     if (LO != LN || RO != RN)
127       return SE.getUDivExpr(LN, RN);
128     return S;
129   }
130 
131   llvm_unreachable("Unexpected SCEV kind!");
132 }
133 
134 /// Manage recursive transformation across an expression DAG. Revisiting
135 /// expressions would lead to exponential recursion.
136 static const SCEV *transformSubExpr(const TransformKind Kind,
137                                     NormalizePredTy Pred, ScalarEvolution &SE,
138                                     NormalizedCacheTy &Cache, const SCEV *S) {
139   if (isa<SCEVConstant>(S) || isa<SCEVUnknown>(S))
140     return S;
141 
142   const SCEV *Result = Cache.lookup(S);
143   if (Result)
144     return Result;
145 
146   Result = transformImpl(Kind, Pred, SE, Cache, S);
147   Cache[S] = Result;
148   return Result;
149 }
150 
151 const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
152                                          const PostIncLoopSet &Loops,
153                                          ScalarEvolution &SE) {
154   auto Pred = [&](const SCEVAddRecExpr *AR) {
155     return Loops.count(AR->getLoop());
156   };
157   NormalizedCacheTy Cache;
158   return transformSubExpr(Normalize, Pred, SE, Cache, S);
159 }
160 
161 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
162                                            ScalarEvolution &SE) {
163   NormalizedCacheTy Cache;
164   return transformSubExpr(Normalize, Pred, SE, Cache, S);
165 }
166 
167 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
168                                            const PostIncLoopSet &Loops,
169                                            ScalarEvolution &SE) {
170   auto Pred = [&](const SCEVAddRecExpr *AR) {
171     return Loops.count(AR->getLoop());
172   };
173   NormalizedCacheTy Cache;
174   return transformSubExpr(Denormalize, Pred, SE, Cache, S);
175 }
176