1 //===------ IslExprBuilder.cpp ----- Code generate isl AST expressions ----===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //===----------------------------------------------------------------------===//
11 
12 #include "polly/CodeGen/IslExprBuilder.h"
13 
14 #include "polly/ScopInfo.h"
15 #include "polly/Support/GICHelper.h"
16 
17 #include "llvm/Analysis/ScalarEvolutionExpander.h"
18 #include "llvm/Support/Debug.h"
19 
20 using namespace llvm;
21 using namespace polly;
22 
23 Type *IslExprBuilder::getWidestType(Type *T1, Type *T2) {
24   assert(isa<IntegerType>(T1) && isa<IntegerType>(T2));
25 
26   if (T1->getPrimitiveSizeInBits() < T2->getPrimitiveSizeInBits())
27     return T2;
28   else
29     return T1;
30 }
31 
32 Value *IslExprBuilder::createOpUnary(__isl_take isl_ast_expr *Expr) {
33   assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_minus &&
34          "Unsupported unary operation");
35 
36   Value *V;
37   Type *MaxType = getType(Expr);
38 
39   V = create(isl_ast_expr_get_op_arg(Expr, 0));
40   MaxType = getWidestType(MaxType, V->getType());
41 
42   if (MaxType != V->getType())
43     V = Builder.CreateSExt(V, MaxType);
44 
45   isl_ast_expr_free(Expr);
46   return Builder.CreateNSWNeg(V);
47 }
48 
49 Value *IslExprBuilder::createOpNAry(__isl_take isl_ast_expr *Expr) {
50   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
51          "isl ast expression not of type isl_ast_op");
52   assert(isl_ast_expr_get_op_n_arg(Expr) >= 2 &&
53          "We need at least two operands in an n-ary operation");
54 
55   Value *V;
56 
57   V = create(isl_ast_expr_get_op_arg(Expr, 0));
58 
59   for (int i = 0; i < isl_ast_expr_get_op_n_arg(Expr); ++i) {
60     Value *OpV;
61     OpV = create(isl_ast_expr_get_op_arg(Expr, i));
62 
63     Type *Ty = getWidestType(V->getType(), OpV->getType());
64 
65     if (Ty != OpV->getType())
66       OpV = Builder.CreateSExt(OpV, Ty);
67 
68     if (Ty != V->getType())
69       V = Builder.CreateSExt(V, Ty);
70 
71     switch (isl_ast_expr_get_op_type(Expr)) {
72     default:
73       llvm_unreachable("This is no n-ary isl ast expression");
74 
75     case isl_ast_op_max: {
76       Value *Cmp = Builder.CreateICmpSGT(V, OpV);
77       V = Builder.CreateSelect(Cmp, V, OpV);
78       continue;
79     }
80     case isl_ast_op_min: {
81       Value *Cmp = Builder.CreateICmpSLT(V, OpV);
82       V = Builder.CreateSelect(Cmp, V, OpV);
83       continue;
84     }
85     }
86   }
87 
88   // TODO: We can truncate the result, if it fits into a smaller type. This can
89   // help in cases where we have larger operands (e.g. i67) but the result is
90   // known to fit into i64. Without the truncation, the larger i67 type may
91   // force all subsequent operations to be performed on a non-native type.
92   isl_ast_expr_free(Expr);
93   return V;
94 }
95 
96 Value *IslExprBuilder::createAccessAddress(isl_ast_expr *Expr) {
97   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
98          "isl ast expression not of type isl_ast_op");
99   assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_access &&
100          "not an access isl ast expression");
101   assert(isl_ast_expr_get_op_n_arg(Expr) >= 2 &&
102          "We need at least two operands to create a member access.");
103 
104   Value *Base, *IndexOp, *Access;
105   isl_ast_expr *BaseExpr;
106   isl_id *BaseId;
107 
108   BaseExpr = isl_ast_expr_get_op_arg(Expr, 0);
109   BaseId = isl_ast_expr_get_id(BaseExpr);
110   isl_ast_expr_free(BaseExpr);
111 
112   const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(BaseId);
113   Base = SAI->getBasePtr();
114   assert(Base->getType()->isPointerTy() && "Access base should be a pointer");
115   StringRef BaseName = Base->getName();
116 
117   if (Base->getType() != SAI->getType())
118     Base = Builder.CreateBitCast(Base, SAI->getType(),
119                                  "polly.access.cast." + BaseName);
120 
121   IndexOp = nullptr;
122   for (unsigned u = 1, e = isl_ast_expr_get_op_n_arg(Expr); u < e; u++) {
123     Value *NextIndex = create(isl_ast_expr_get_op_arg(Expr, u));
124     assert(NextIndex->getType()->isIntegerTy() &&
125            "Access index should be an integer");
126 
127     if (!IndexOp)
128       IndexOp = NextIndex;
129     else
130       IndexOp =
131           Builder.CreateAdd(IndexOp, NextIndex, "polly.access.add." + BaseName);
132 
133     // For every but the last dimension multiply the size, for the last
134     // dimension we can exit the loop.
135     if (u + 1 >= e)
136       break;
137 
138     const SCEV *DimSCEV = SAI->getDimensionSize(u - 1);
139     Value *DimSize = Expander.expandCodeFor(DimSCEV, DimSCEV->getType(),
140                                             Builder.GetInsertPoint());
141 
142     Type *Ty = getWidestType(DimSize->getType(), IndexOp->getType());
143 
144     if (Ty != IndexOp->getType())
145       IndexOp = Builder.CreateSExtOrTrunc(IndexOp, Ty,
146                                           "polly.access.sext." + BaseName);
147 
148     IndexOp =
149         Builder.CreateMul(IndexOp, DimSize, "polly.access.mul." + BaseName);
150   }
151 
152   Access = Builder.CreateGEP(Base, IndexOp, "polly.access." + BaseName);
153 
154   isl_ast_expr_free(Expr);
155   return Access;
156 }
157 
158 Value *IslExprBuilder::createOpAccess(isl_ast_expr *Expr) {
159   Value *Addr = createAccessAddress(Expr);
160   assert(Addr && "Could not create op access address");
161   return Builder.CreateLoad(Addr, Addr->getName() + ".load");
162 }
163 
/// Generate code for a binary arithmetic isl AST expression.
///
/// Handles add, sub, mul, div, fdiv_q, pdiv_q, pdiv_r and zdiv_r. Both
/// operands are sign-extended to a common integer type before the operation is
/// emitted. For add/sub/mul the common type also covers the type reported by
/// getType(Expr). Takes ownership of \p Expr and frees it.
Value *IslExprBuilder::createOpBin(__isl_take isl_ast_expr *Expr) {
  Value *LHS, *RHS, *Res;
  Type *MaxType;
  isl_ast_op_type OpType;

  assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
         "isl ast expression not of type isl_ast_op");
  assert(isl_ast_expr_get_op_n_arg(Expr) == 2 &&
         "not a binary isl ast expression");

  OpType = isl_ast_expr_get_op_type(Expr);

  LHS = create(isl_ast_expr_get_op_arg(Expr, 0));
  RHS = create(isl_ast_expr_get_op_arg(Expr, 1));

  MaxType = LHS->getType();
  MaxType = getWidestType(MaxType, RHS->getType());

  // Take the result into account when calculating the widest type.
  //
  // For operations such as '+' the result may require a type larger than
  // the type of the individual operands. For other operations such as '/', the
  // result type cannot be larger than the type of the individual operands. isl
  // does not calculate correct types for these operations and we consequently
  // exclude those operations here.
  switch (OpType) {
  case isl_ast_op_pdiv_q:
  case isl_ast_op_pdiv_r:
  case isl_ast_op_div:
  case isl_ast_op_fdiv_q:
  case isl_ast_op_zdiv_r:
    // Do nothing
    break;
  case isl_ast_op_add:
  case isl_ast_op_sub:
  case isl_ast_op_mul:
    MaxType = getWidestType(MaxType, getType(Expr));
    break;
  default:
    llvm_unreachable("This is no binary isl ast expression");
  }

  // Bring both operands to the common type; IR arithmetic requires identical
  // operand types.
  if (MaxType != RHS->getType())
    RHS = Builder.CreateSExt(RHS, MaxType);

  if (MaxType != LHS->getType())
    LHS = Builder.CreateSExt(LHS, MaxType);

  switch (OpType) {
  default:
    llvm_unreachable("This is no binary isl ast expression");
  case isl_ast_op_add:
    Res = Builder.CreateNSWAdd(LHS, RHS);
    break;
  case isl_ast_op_sub:
    Res = Builder.CreateNSWSub(LHS, RHS);
    break;
  case isl_ast_op_mul:
    Res = Builder.CreateNSWMul(LHS, RHS);
    break;
  case isl_ast_op_div:
  case isl_ast_op_pdiv_q: // Dividend is non-negative
    Res = Builder.CreateSDiv(LHS, RHS);
    break;
  case isl_ast_op_fdiv_q: { // Round towards -infty
    // TODO: Review code and check that this calculation does not yield
    //       incorrect overflow in some border cases.
    //
    // floord(n,d) ((n < 0) ? (n - d + 1) : n) / d
    //
    // NOTE(review): this formula yields floor division only for a positive
    // divisor d; presumably isl guarantees d > 0 here -- confirm. Note also
    // that the sub/add below carry no nsw flag, unlike the add/sub cases
    // above.
    Value *One = ConstantInt::get(MaxType, 1);
    Value *Zero = ConstantInt::get(MaxType, 0);
    Value *Sum1 = Builder.CreateSub(LHS, RHS);
    Value *Sum2 = Builder.CreateAdd(Sum1, One);
    Value *isNegative = Builder.CreateICmpSLT(LHS, Zero);
    Value *Dividend = Builder.CreateSelect(isNegative, Sum2, LHS);
    Res = Builder.CreateSDiv(Dividend, RHS);
    break;
  }
  case isl_ast_op_pdiv_r: // Dividend is non-negative
  case isl_ast_op_zdiv_r: // Result only compared against zero
    Res = Builder.CreateSRem(LHS, RHS);
    break;
  }

  // TODO: We can truncate the result, if it fits into a smaller type. This can
  // help in cases where we have larger operands (e.g. i67) but the result is
  // known to fit into i64. Without the truncation, the larger i67 type may
  // force all subsequent operations to be performed on a non-native type.
  isl_ast_expr_free(Expr);
  return Res;
}
255 
256 Value *IslExprBuilder::createOpSelect(__isl_take isl_ast_expr *Expr) {
257   assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_select &&
258          "Unsupported unary isl ast expression");
259   Value *LHS, *RHS, *Cond;
260   Type *MaxType = getType(Expr);
261 
262   Cond = create(isl_ast_expr_get_op_arg(Expr, 0));
263   if (!Cond->getType()->isIntegerTy(1))
264     Cond = Builder.CreateIsNotNull(Cond);
265 
266   LHS = create(isl_ast_expr_get_op_arg(Expr, 1));
267   RHS = create(isl_ast_expr_get_op_arg(Expr, 2));
268 
269   MaxType = getWidestType(MaxType, LHS->getType());
270   MaxType = getWidestType(MaxType, RHS->getType());
271 
272   if (MaxType != RHS->getType())
273     RHS = Builder.CreateSExt(RHS, MaxType);
274 
275   if (MaxType != LHS->getType())
276     LHS = Builder.CreateSExt(LHS, MaxType);
277 
278   // TODO: Do we want to truncate the result?
279   isl_ast_expr_free(Expr);
280   return Builder.CreateSelect(Cond, LHS, RHS);
281 }
282 
283 Value *IslExprBuilder::createOpICmp(__isl_take isl_ast_expr *Expr) {
284   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
285          "Expected an isl_ast_expr_op expression");
286 
287   Value *LHS, *RHS, *Res;
288 
289   LHS = create(isl_ast_expr_get_op_arg(Expr, 0));
290   RHS = create(isl_ast_expr_get_op_arg(Expr, 1));
291 
292   bool IsPtrType =
293       LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy();
294 
295   if (LHS->getType() != RHS->getType()) {
296     if (IsPtrType) {
297       Type *I8PtrTy = Builder.getInt8PtrTy();
298       if (!LHS->getType()->isPointerTy())
299         LHS = Builder.CreateIntToPtr(LHS, I8PtrTy);
300       if (!RHS->getType()->isPointerTy())
301         RHS = Builder.CreateIntToPtr(RHS, I8PtrTy);
302       if (LHS->getType() != I8PtrTy)
303         LHS = Builder.CreateBitCast(LHS, I8PtrTy);
304       if (RHS->getType() != I8PtrTy)
305         RHS = Builder.CreateBitCast(RHS, I8PtrTy);
306     } else {
307       Type *MaxType = LHS->getType();
308       MaxType = getWidestType(MaxType, RHS->getType());
309 
310       if (MaxType != RHS->getType())
311         RHS = Builder.CreateSExt(RHS, MaxType);
312 
313       if (MaxType != LHS->getType())
314         LHS = Builder.CreateSExt(LHS, MaxType);
315     }
316   }
317 
318   isl_ast_op_type OpType = isl_ast_expr_get_op_type(Expr);
319   assert(OpType >= isl_ast_op_eq && OpType <= isl_ast_op_gt &&
320          "Unsupported ICmp isl ast expression");
321   assert(isl_ast_op_eq + 4 == isl_ast_op_gt &&
322          "Isl ast op type interface changed");
323 
324   CmpInst::Predicate Predicates[5][2] = {
325       {CmpInst::ICMP_EQ, CmpInst::ICMP_EQ},
326       {CmpInst::ICMP_SLE, CmpInst::ICMP_ULE},
327       {CmpInst::ICMP_SLT, CmpInst::ICMP_ULT},
328       {CmpInst::ICMP_SGE, CmpInst::ICMP_UGE},
329       {CmpInst::ICMP_SGT, CmpInst::ICMP_UGT},
330   };
331 
332   Res = Builder.CreateICmp(Predicates[OpType - isl_ast_op_eq][IsPtrType], LHS,
333                            RHS);
334 
335   isl_ast_expr_free(Expr);
336   return Res;
337 }
338 
339 Value *IslExprBuilder::createOpBoolean(__isl_take isl_ast_expr *Expr) {
340   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
341          "Expected an isl_ast_expr_op expression");
342 
343   Value *LHS, *RHS, *Res;
344   isl_ast_op_type OpType;
345 
346   OpType = isl_ast_expr_get_op_type(Expr);
347 
348   assert((OpType == isl_ast_op_and || OpType == isl_ast_op_or) &&
349          "Unsupported isl_ast_op_type");
350 
351   LHS = create(isl_ast_expr_get_op_arg(Expr, 0));
352   RHS = create(isl_ast_expr_get_op_arg(Expr, 1));
353 
354   // Even though the isl pretty printer prints the expressions as 'exp && exp'
355   // or 'exp || exp', we actually code generate the bitwise expressions
356   // 'exp & exp' or 'exp | exp'. This forces the evaluation of both branches,
357   // but it is, due to the use of i1 types, otherwise equivalent. The reason
358   // to go for bitwise operations is, that we assume the reduced control flow
359   // will outweight the overhead introduced by evaluating unneeded expressions.
360   // The isl code generation currently does not take advantage of the fact that
361   // the expression after an '||' or '&&' is in some cases not evaluated.
362   // Evaluating it anyways does not cause any undefined behaviour.
363   //
364   // TODO: Document in isl itself, that the unconditionally evaluating the
365   // second part of '||' or '&&' expressions is safe.
366   if (!LHS->getType()->isIntegerTy(1))
367     LHS = Builder.CreateIsNotNull(LHS);
368   if (!RHS->getType()->isIntegerTy(1))
369     RHS = Builder.CreateIsNotNull(RHS);
370 
371   switch (OpType) {
372   default:
373     llvm_unreachable("Unsupported boolean expression");
374   case isl_ast_op_and:
375     Res = Builder.CreateAnd(LHS, RHS);
376     break;
377   case isl_ast_op_or:
378     Res = Builder.CreateOr(LHS, RHS);
379     break;
380   }
381 
382   isl_ast_expr_free(Expr);
383   return Res;
384 }
385 
386 Value *IslExprBuilder::createOp(__isl_take isl_ast_expr *Expr) {
387   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
388          "Expression not of type isl_ast_expr_op");
389   switch (isl_ast_expr_get_op_type(Expr)) {
390   case isl_ast_op_error:
391   case isl_ast_op_cond:
392   case isl_ast_op_and_then:
393   case isl_ast_op_or_else:
394   case isl_ast_op_call:
395   case isl_ast_op_member:
396     llvm_unreachable("Unsupported isl ast expression");
397   case isl_ast_op_access:
398     return createOpAccess(Expr);
399   case isl_ast_op_max:
400   case isl_ast_op_min:
401     return createOpNAry(Expr);
402   case isl_ast_op_add:
403   case isl_ast_op_sub:
404   case isl_ast_op_mul:
405   case isl_ast_op_div:
406   case isl_ast_op_fdiv_q: // Round towards -infty
407   case isl_ast_op_pdiv_q: // Dividend is non-negative
408   case isl_ast_op_pdiv_r: // Dividend is non-negative
409   case isl_ast_op_zdiv_r: // Result only compared against zero
410     return createOpBin(Expr);
411   case isl_ast_op_minus:
412     return createOpUnary(Expr);
413   case isl_ast_op_select:
414     return createOpSelect(Expr);
415   case isl_ast_op_and:
416   case isl_ast_op_or:
417     return createOpBoolean(Expr);
418   case isl_ast_op_eq:
419   case isl_ast_op_le:
420   case isl_ast_op_lt:
421   case isl_ast_op_ge:
422   case isl_ast_op_gt:
423     return createOpICmp(Expr);
424   case isl_ast_op_address_of:
425     return createOpAddressOf(Expr);
426   }
427 
428   llvm_unreachable("Unsupported isl_ast_expr_op kind.");
429 }
430 
431 Value *IslExprBuilder::createOpAddressOf(__isl_take isl_ast_expr *Expr) {
432   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
433          "Expected an isl_ast_expr_op expression.");
434   assert(isl_ast_expr_get_op_n_arg(Expr) == 1 && "Address of should be unary.");
435 
436   isl_ast_expr *Op = isl_ast_expr_get_op_arg(Expr, 0);
437   assert(isl_ast_expr_get_type(Op) == isl_ast_expr_op &&
438          "Expected address of operator to be an isl_ast_expr_op expression.");
439   assert(isl_ast_expr_get_op_type(Op) == isl_ast_op_access &&
440          "Expected address of operator to be an access expression.");
441 
442   Value *V = createAccessAddress(Op);
443 
444   isl_ast_expr_free(Expr);
445 
446   return V;
447 }
448 
449 Value *IslExprBuilder::createId(__isl_take isl_ast_expr *Expr) {
450   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_id &&
451          "Expression not of type isl_ast_expr_ident");
452 
453   isl_id *Id;
454   Value *V;
455 
456   Id = isl_ast_expr_get_id(Expr);
457 
458   assert(IDToValue.count(Id) && "Identifier not found");
459 
460   V = IDToValue[Id];
461 
462   isl_id_free(Id);
463   isl_ast_expr_free(Expr);
464 
465   return V;
466 }
467 
468 IntegerType *IslExprBuilder::getType(__isl_keep isl_ast_expr *Expr) {
469   // XXX: We assume i64 is large enough. This is often true, but in general
470   //      incorrect. Also, on 32bit architectures, it would be beneficial to
471   //      use a smaller type. We can and should directly derive this information
472   //      during code generation.
473   return IntegerType::get(Builder.getContext(), 64);
474 }
475 
476 Value *IslExprBuilder::createInt(__isl_take isl_ast_expr *Expr) {
477   assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_int &&
478          "Expression not of type isl_ast_expr_int");
479   isl_val *Val;
480   Value *V;
481   APInt APValue;
482   IntegerType *T;
483 
484   Val = isl_ast_expr_get_val(Expr);
485   APValue = APIntFromVal(Val);
486   T = getType(Expr);
487   APValue = APValue.sextOrSelf(T->getBitWidth());
488   V = ConstantInt::get(T, APValue);
489 
490   isl_ast_expr_free(Expr);
491   return V;
492 }
493 
494 Value *IslExprBuilder::create(__isl_take isl_ast_expr *Expr) {
495   switch (isl_ast_expr_get_type(Expr)) {
496   case isl_ast_expr_error:
497     llvm_unreachable("Code generation error");
498   case isl_ast_expr_op:
499     return createOp(Expr);
500   case isl_ast_expr_id:
501     return createId(Expr);
502   case isl_ast_expr_int:
503     return createInt(Expr);
504   }
505 
506   llvm_unreachable("Unexpected enum value");
507 }
508