1 //===------ IslExprBuilder.cpp ----- Code generate isl AST expressions ----===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //===----------------------------------------------------------------------===//
11 
12 #include "polly/CodeGen/IslExprBuilder.h"
13 
14 #include "polly/ScopInfo.h"
15 #include "polly/Support/GICHelper.h"
16 
17 #include "llvm/Analysis/ScalarEvolutionExpander.h"
18 #include "llvm/Support/Debug.h"
19 
20 using namespace llvm;
21 using namespace polly;
22 
23 Type *IslExprBuilder::getWidestType(Type *T1, Type *T2) {
24   assert(isa<IntegerType>(T1) && isa<IntegerType>(T2));
25 
26   if (T1->getPrimitiveSizeInBits() < T2->getPrimitiveSizeInBits())
27     return T2;
28   else
29     return T1;
30 }
31 
32 Value *IslExprBuilder::createOpUnary(__isl_take isl_ast_expr *Expr) {
33   assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_minus &&
34          "Unsupported unary operation");
35 
36   Value *V;
37   Type *MaxType = getType(Expr);
38 
39   V = create(isl_ast_expr_get_op_arg(Expr, 0));
40   MaxType = getWidestType(MaxType, V->getType());
41 
42   if (MaxType != V->getType())
43     V = Builder.CreateSExt(V, MaxType);
44 
45   isl_ast_expr_free(Expr);
46   return Builder.CreateNSWNeg(V);
47 }
48 
49 Value *IslExprBuilder::createOpNAry(__isl_take isl_ast_expr *Expr) {
50   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
51          "isl ast expression not of type isl_ast_op");
52   assert(isl_ast_expr_get_op_n_arg(Expr) >= 2 &&
53          "We need at least two operands in an n-ary operation");
54 
55   Value *V;
56 
57   V = create(isl_ast_expr_get_op_arg(Expr, 0));
58 
59   for (int i = 0; i < isl_ast_expr_get_op_n_arg(Expr); ++i) {
60     Value *OpV;
61     OpV = create(isl_ast_expr_get_op_arg(Expr, i));
62 
63     Type *Ty = getWidestType(V->getType(), OpV->getType());
64 
65     if (Ty != OpV->getType())
66       OpV = Builder.CreateSExt(OpV, Ty);
67 
68     if (Ty != V->getType())
69       V = Builder.CreateSExt(V, Ty);
70 
71     switch (isl_ast_expr_get_op_type(Expr)) {
72     default:
73       llvm_unreachable("This is no n-ary isl ast expression");
74 
75     case isl_ast_op_max: {
76       Value *Cmp = Builder.CreateICmpSGT(V, OpV);
77       V = Builder.CreateSelect(Cmp, V, OpV);
78       continue;
79     }
80     case isl_ast_op_min: {
81       Value *Cmp = Builder.CreateICmpSLT(V, OpV);
82       V = Builder.CreateSelect(Cmp, V, OpV);
83       continue;
84     }
85     }
86   }
87 
88   // TODO: We can truncate the result, if it fits into a smaller type. This can
89   // help in cases where we have larger operands (e.g. i67) but the result is
90   // known to fit into i64. Without the truncation, the larger i67 type may
91   // force all subsequent operations to be performed on a non-native type.
92   isl_ast_expr_free(Expr);
93   return V;
94 }
95 
96 Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) {
97   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
98          "isl ast expression not of type isl_ast_op");
99   assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_access &&
100          "not an access isl ast expression");
101   assert(isl_ast_expr_get_op_n_arg(Expr) >= 2 &&
102          "We need at least two operands to create a member access.");
103 
104   Value *Base, *IndexOp, *Access;
105   isl_ast_expr *BaseExpr;
106   isl_id *BaseId;
107 
108   BaseExpr = isl_ast_expr_get_op_arg(Expr, 0);
109   BaseId = isl_ast_expr_get_id(BaseExpr);
110   isl_ast_expr_free(BaseExpr);
111 
112   const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(BaseId);
113   Base = SAI->getBasePtr();
114   assert(Base->getType()->isPointerTy() && "Access base should be a pointer");
115   const Twine &BaseName = Base->getName();
116 
117   if (Base->getType() != SAI->getType())
118     Base = Builder.CreateBitCast(Base, SAI->getType(),
119                                  "polly.access.cast." + BaseName);
120 
121   IndexOp = nullptr;
122   for (unsigned u = 1, e = isl_ast_expr_get_op_n_arg(Expr); u < e; u++) {
123     Value *NextIndex = create(isl_ast_expr_get_op_arg(Expr, u));
124     assert(NextIndex->getType()->isIntegerTy() &&
125            "Access index should be an integer");
126 
127     if (!IndexOp)
128       IndexOp = NextIndex;
129     else
130       IndexOp = Builder.CreateAdd(IndexOp, NextIndex);
131 
132     // For every but the last dimension multiply the size, for the last
133     // dimension we can exit the loop.
134     if (u + 1 >= e)
135       break;
136 
137     const SCEV *DimSCEV = SAI->getDimensionSize(u - 1);
138     Value *DimSize = Expander.expandCodeFor(DimSCEV, IndexOp->getType(),
139                                             Builder.GetInsertPoint());
140     IndexOp = Builder.CreateMul(IndexOp, DimSize);
141   }
142 
143   Access = Builder.CreateGEP(Base, IndexOp, "polly.access." + BaseName);
144 
145   isl_ast_expr_free(Expr);
146   return Access;
147 }
148 
149 Value *IslExprBuilder::createOpAccess(isl_ast_expr *Expr) {
150   Value *Addr = createAccessAddress(Expr);
151   assert(Addr && "Could not create op access address");
152   return Builder.CreateLoad(Addr, Addr->getName() + ".load");
153 }
154 
155 Value *IslExprBuilder::createOpBin(__isl_take isl_ast_expr *Expr) {
156   Value *LHS, *RHS, *Res;
157   Type *MaxType;
158   isl_ast_op_type OpType;
159 
160   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
161          "isl ast expression not of type isl_ast_op");
162   assert(isl_ast_expr_get_op_n_arg(Expr) == 2 &&
163          "not a binary isl ast expression");
164 
165   OpType = isl_ast_expr_get_op_type(Expr);
166 
167   LHS = create(isl_ast_expr_get_op_arg(Expr, 0));
168   RHS = create(isl_ast_expr_get_op_arg(Expr, 1));
169 
170   MaxType = LHS->getType();
171   MaxType = getWidestType(MaxType, RHS->getType());
172 
173   // Take the result into account when calculating the widest type.
174   //
175   // For operations such as '+' the result may require a type larger than
176   // the type of the individual operands. For other operations such as '/', the
177   // result type cannot be larger than the type of the individual operand. isl
178   // does not calculate correct types for these operations and we consequently
179   // exclude those operations here.
180   switch (OpType) {
181   case isl_ast_op_pdiv_q:
182   case isl_ast_op_pdiv_r:
183   case isl_ast_op_div:
184   case isl_ast_op_fdiv_q:
185     // Do nothing
186     break;
187   case isl_ast_op_add:
188   case isl_ast_op_sub:
189   case isl_ast_op_mul:
190     MaxType = getWidestType(MaxType, getType(Expr));
191     break;
192   default:
193     llvm_unreachable("This is no binary isl ast expression");
194   }
195 
196   if (MaxType != RHS->getType())
197     RHS = Builder.CreateSExt(RHS, MaxType);
198 
199   if (MaxType != LHS->getType())
200     LHS = Builder.CreateSExt(LHS, MaxType);
201 
202   switch (OpType) {
203   default:
204     llvm_unreachable("This is no binary isl ast expression");
205   case isl_ast_op_add:
206     Res = Builder.CreateNSWAdd(LHS, RHS);
207     break;
208   case isl_ast_op_sub:
209     Res = Builder.CreateNSWSub(LHS, RHS);
210     break;
211   case isl_ast_op_mul:
212     Res = Builder.CreateNSWMul(LHS, RHS);
213     break;
214   case isl_ast_op_div:
215   case isl_ast_op_pdiv_q: // Dividend is non-negative
216     Res = Builder.CreateSDiv(LHS, RHS);
217     break;
218   case isl_ast_op_fdiv_q: { // Round towards -infty
219     // TODO: Review code and check that this calculation does not yield
220     //       incorrect overflow in some bordercases.
221     //
222     // floord(n,d) ((n < 0) ? (n - d + 1) : n) / d
223     Value *One = ConstantInt::get(MaxType, 1);
224     Value *Zero = ConstantInt::get(MaxType, 0);
225     Value *Sum1 = Builder.CreateSub(LHS, RHS);
226     Value *Sum2 = Builder.CreateAdd(Sum1, One);
227     Value *isNegative = Builder.CreateICmpSLT(LHS, Zero);
228     Value *Dividend = Builder.CreateSelect(isNegative, Sum2, LHS);
229     Res = Builder.CreateSDiv(Dividend, RHS);
230     break;
231   }
232   case isl_ast_op_pdiv_r: // Dividend is non-negative
233     Res = Builder.CreateSRem(LHS, RHS);
234     break;
235   }
236 
237   // TODO: We can truncate the result, if it fits into a smaller type. This can
238   // help in cases where we have larger operands (e.g. i67) but the result is
239   // known to fit into i64. Without the truncation, the larger i67 type may
240   // force all subsequent operations to be performed on a non-native type.
241   isl_ast_expr_free(Expr);
242   return Res;
243 }
244 
245 Value *IslExprBuilder::createOpSelect(__isl_take isl_ast_expr *Expr) {
246   assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_select &&
247          "Unsupported unary isl ast expression");
248   Value *LHS, *RHS, *Cond;
249   Type *MaxType = getType(Expr);
250 
251   Cond = create(isl_ast_expr_get_op_arg(Expr, 0));
252   if (!Cond->getType()->isIntegerTy(1))
253     Cond = Builder.CreateIsNotNull(Cond);
254 
255   LHS = create(isl_ast_expr_get_op_arg(Expr, 1));
256   RHS = create(isl_ast_expr_get_op_arg(Expr, 2));
257 
258   MaxType = getWidestType(MaxType, LHS->getType());
259   MaxType = getWidestType(MaxType, RHS->getType());
260 
261   if (MaxType != RHS->getType())
262     RHS = Builder.CreateSExt(RHS, MaxType);
263 
264   if (MaxType != LHS->getType())
265     LHS = Builder.CreateSExt(LHS, MaxType);
266 
267   // TODO: Do we want to truncate the result?
268   isl_ast_expr_free(Expr);
269   return Builder.CreateSelect(Cond, LHS, RHS);
270 }
271 
272 Value *IslExprBuilder::createOpICmp(__isl_take isl_ast_expr *Expr) {
273   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
274          "Expected an isl_ast_expr_op expression");
275 
276   Value *LHS, *RHS, *Res;
277 
278   LHS = create(isl_ast_expr_get_op_arg(Expr, 0));
279   RHS = create(isl_ast_expr_get_op_arg(Expr, 1));
280 
281   bool IsPtrType = LHS->getType()->isPointerTy();
282   assert((!IsPtrType || RHS->getType()->isPointerTy()) &&
283          "Both ICmp operators should be pointer types or none of them");
284 
285   if (LHS->getType() != RHS->getType()) {
286     if (IsPtrType) {
287       Type *I8PtrTy = Builder.getInt8PtrTy();
288       if (LHS->getType() != I8PtrTy)
289         LHS = Builder.CreateBitCast(LHS, I8PtrTy);
290       if (RHS->getType() != I8PtrTy)
291         RHS = Builder.CreateBitCast(RHS, I8PtrTy);
292     } else {
293       Type *MaxType = LHS->getType();
294       MaxType = getWidestType(MaxType, RHS->getType());
295 
296       if (MaxType != RHS->getType())
297         RHS = Builder.CreateSExt(RHS, MaxType);
298 
299       if (MaxType != LHS->getType())
300         LHS = Builder.CreateSExt(LHS, MaxType);
301     }
302   }
303 
304   isl_ast_op_type OpType = isl_ast_expr_get_op_type(Expr);
305   assert(OpType >= isl_ast_op_eq && OpType <= isl_ast_op_gt &&
306          "Unsupported ICmp isl ast expression");
307   assert(isl_ast_op_eq + 4 == isl_ast_op_gt &&
308          "Isl ast op type interface changed");
309 
310   CmpInst::Predicate Predicates[5][2] = {
311       {CmpInst::ICMP_EQ, CmpInst::ICMP_EQ},
312       {CmpInst::ICMP_SLE, CmpInst::ICMP_ULE},
313       {CmpInst::ICMP_SLT, CmpInst::ICMP_ULT},
314       {CmpInst::ICMP_SGE, CmpInst::ICMP_UGE},
315       {CmpInst::ICMP_SGT, CmpInst::ICMP_UGT},
316   };
317 
318   Res = Builder.CreateICmp(Predicates[OpType - isl_ast_op_eq][IsPtrType], LHS,
319                            RHS);
320 
321   isl_ast_expr_free(Expr);
322   return Res;
323 }
324 
325 Value *IslExprBuilder::createOpBoolean(__isl_take isl_ast_expr *Expr) {
326   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
327          "Expected an isl_ast_expr_op expression");
328 
329   Value *LHS, *RHS, *Res;
330   isl_ast_op_type OpType;
331 
332   OpType = isl_ast_expr_get_op_type(Expr);
333 
334   assert((OpType == isl_ast_op_and || OpType == isl_ast_op_or) &&
335          "Unsupported isl_ast_op_type");
336 
337   LHS = create(isl_ast_expr_get_op_arg(Expr, 0));
338   RHS = create(isl_ast_expr_get_op_arg(Expr, 1));
339 
340   // Even though the isl pretty printer prints the expressions as 'exp && exp'
341   // or 'exp || exp', we actually code generate the bitwise expressions
342   // 'exp & exp' or 'exp | exp'. This forces the evaluation of both branches,
343   // but it is, due to the use of i1 types, otherwise equivalent. The reason
344   // to go for bitwise operations is, that we assume the reduced control flow
345   // will outweight the overhead introduced by evaluating unneeded expressions.
346   // The isl code generation currently does not take advantage of the fact that
347   // the expression after an '||' or '&&' is in some cases not evaluated.
348   // Evaluating it anyways does not cause any undefined behaviour.
349   //
350   // TODO: Document in isl itself, that the unconditionally evaluating the
351   // second part of '||' or '&&' expressions is safe.
352   if (!LHS->getType()->isIntegerTy(1))
353     LHS = Builder.CreateIsNotNull(LHS);
354   if (!RHS->getType()->isIntegerTy(1))
355     RHS = Builder.CreateIsNotNull(RHS);
356 
357   switch (OpType) {
358   default:
359     llvm_unreachable("Unsupported boolean expression");
360   case isl_ast_op_and:
361     Res = Builder.CreateAnd(LHS, RHS);
362     break;
363   case isl_ast_op_or:
364     Res = Builder.CreateOr(LHS, RHS);
365     break;
366   }
367 
368   isl_ast_expr_free(Expr);
369   return Res;
370 }
371 
372 Value *IslExprBuilder::createOp(__isl_take isl_ast_expr *Expr) {
373   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
374          "Expression not of type isl_ast_expr_op");
375   switch (isl_ast_expr_get_op_type(Expr)) {
376   case isl_ast_op_error:
377   case isl_ast_op_cond:
378   case isl_ast_op_and_then:
379   case isl_ast_op_or_else:
380   case isl_ast_op_call:
381   case isl_ast_op_member:
382     llvm_unreachable("Unsupported isl ast expression");
383   case isl_ast_op_access:
384     return createOpAccess(Expr);
385   case isl_ast_op_max:
386   case isl_ast_op_min:
387     return createOpNAry(Expr);
388   case isl_ast_op_add:
389   case isl_ast_op_sub:
390   case isl_ast_op_mul:
391   case isl_ast_op_div:
392   case isl_ast_op_fdiv_q: // Round towards -infty
393   case isl_ast_op_pdiv_q: // Dividend is non-negative
394   case isl_ast_op_pdiv_r: // Dividend is non-negative
395     return createOpBin(Expr);
396   case isl_ast_op_minus:
397     return createOpUnary(Expr);
398   case isl_ast_op_select:
399     return createOpSelect(Expr);
400   case isl_ast_op_and:
401   case isl_ast_op_or:
402     return createOpBoolean(Expr);
403   case isl_ast_op_eq:
404   case isl_ast_op_le:
405   case isl_ast_op_lt:
406   case isl_ast_op_ge:
407   case isl_ast_op_gt:
408     return createOpICmp(Expr);
409   case isl_ast_op_address_of:
410     return createOpAddressOf(Expr);
411   }
412 
413   llvm_unreachable("Unsupported isl_ast_expr_op kind.");
414 }
415 
416 Value *IslExprBuilder::createOpAddressOf(__isl_take isl_ast_expr *Expr) {
417   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
418          "Expected an isl_ast_expr_op expression.");
419   assert(isl_ast_expr_get_op_n_arg(Expr) == 1 && "Address of should be unary.");
420 
421   isl_ast_expr *Op = isl_ast_expr_get_op_arg(Expr, 0);
422   assert(isl_ast_expr_get_type(Op) == isl_ast_expr_op &&
423          "Expected address of operator to be an isl_ast_expr_op expression.");
424   assert(isl_ast_expr_get_op_type(Op) == isl_ast_op_access &&
425          "Expected address of operator to be an access expression.");
426 
427   Value *V = createAccessAddress(Op);
428 
429   isl_ast_expr_free(Expr);
430 
431   return V;
432 }
433 
434 Value *IslExprBuilder::createId(__isl_take isl_ast_expr *Expr) {
435   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_id &&
436          "Expression not of type isl_ast_expr_ident");
437 
438   isl_id *Id;
439   Value *V;
440 
441   Id = isl_ast_expr_get_id(Expr);
442 
443   assert(IDToValue.count(Id) && "Identifier not found");
444 
445   V = IDToValue[Id];
446 
447   isl_id_free(Id);
448   isl_ast_expr_free(Expr);
449 
450   return V;
451 }
452 
453 IntegerType *IslExprBuilder::getType(__isl_keep isl_ast_expr *Expr) {
454   // XXX: We assume i64 is large enough. This is often true, but in general
455   //      incorrect. Also, on 32bit architectures, it would be beneficial to
456   //      use a smaller type. We can and should directly derive this information
457   //      during code generation.
458   return IntegerType::get(Builder.getContext(), 64);
459 }
460 
461 Value *IslExprBuilder::createInt(__isl_take isl_ast_expr *Expr) {
462   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_int &&
463          "Expression not of type isl_ast_expr_int");
464   isl_val *Val;
465   Value *V;
466   APInt APValue;
467   IntegerType *T;
468 
469   Val = isl_ast_expr_get_val(Expr);
470   APValue = APIntFromVal(Val);
471   T = getType(Expr);
472   APValue = APValue.sextOrSelf(T->getBitWidth());
473   V = ConstantInt::get(T, APValue);
474 
475   isl_ast_expr_free(Expr);
476   return V;
477 }
478 
479 Value *IslExprBuilder::create(__isl_take isl_ast_expr *Expr) {
480   switch (isl_ast_expr_get_type(Expr)) {
481   case isl_ast_expr_error:
482     llvm_unreachable("Code generation error");
483   case isl_ast_expr_op:
484     return createOp(Expr);
485   case isl_ast_expr_id:
486     return createId(Expr);
487   case isl_ast_expr_int:
488     return createInt(Expr);
489   }
490 
491   llvm_unreachable("Unexpected enum value");
492 }
493