xref: /sqlite-3.40.0/ext/misc/decimal.c (revision e103a8de)
1 /*
2 ** 2020-06-22
3 **
4 ** The author disclaims copyright to this source code.  In place of
5 ** a legal notice, here is a blessing:
6 **
7 **    May you do good and not evil.
8 **    May you find forgiveness for yourself and forgive others.
9 **    May you share freely, never taking more than you give.
10 **
11 ******************************************************************************
12 **
13 ** Routines to implement arbitrary-precision decimal math.
14 **
15 ** The focus here is on simplicity and correctness, not performance.
16 */
17 #include "sqlite3ext.h"
18 SQLITE_EXTENSION_INIT1
19 #include <assert.h>
20 #include <string.h>
21 #include <ctype.h>
22 #include <stdlib.h>
23 
24 /* Mark a function parameter as unused, to suppress nuisance compiler
25 ** warnings. */
26 #ifndef UNUSED_PARAMETER
27 # define UNUSED_PARAMETER(X)  (void)(X)
28 #endif
29 
30 
31 /* A decimal object */
32 typedef struct Decimal Decimal;
33 struct Decimal {
34   char sign;        /* 0 for positive, 1 for negative */
35   char oom;         /* True if an OOM is encountered */
36   char isNull;      /* True if holds a NULL rather than a number */
37   char isInit;      /* True upon initialization */
38   int nDigit;       /* Total number of digits */
39   int nFrac;        /* Number of digits to the right of the decimal point */
40   signed char *a;   /* Array of digits.  Most significant first. */
41 };
42 
43 /*
44 ** Release memory held by a Decimal, but do not free the object itself.
45 */
decimal_clear(Decimal * p)46 static void decimal_clear(Decimal *p){
47   sqlite3_free(p->a);
48 }
49 
50 /*
51 ** Destroy a Decimal object
52 */
decimal_free(Decimal * p)53 static void decimal_free(Decimal *p){
54   if( p ){
55     decimal_clear(p);
56     sqlite3_free(p);
57   }
58 }
59 
60 /*
61 ** Allocate a new Decimal object.  Initialize it to the number given
62 ** by the input string.
63 */
decimal_new(sqlite3_context * pCtx,sqlite3_value * pIn,int nAlt,const unsigned char * zAlt)64 static Decimal *decimal_new(
65   sqlite3_context *pCtx,
66   sqlite3_value *pIn,
67   int nAlt,
68   const unsigned char *zAlt
69 ){
70   Decimal *p;
71   int n, i;
72   const unsigned char *zIn;
73   int iExp = 0;
74   p = sqlite3_malloc( sizeof(*p) );
75   if( p==0 ) goto new_no_mem;
76   p->sign = 0;
77   p->oom = 0;
78   p->isInit = 1;
79   p->isNull = 0;
80   p->nDigit = 0;
81   p->nFrac = 0;
82   if( zAlt ){
83     n = nAlt,
84     zIn = zAlt;
85   }else{
86     if( sqlite3_value_type(pIn)==SQLITE_NULL ){
87       p->a = 0;
88       p->isNull = 1;
89       return p;
90     }
91     n = sqlite3_value_bytes(pIn);
92     zIn = sqlite3_value_text(pIn);
93   }
94   p->a = sqlite3_malloc64( n+1 );
95   if( p->a==0 ) goto new_no_mem;
96   for(i=0; isspace(zIn[i]); i++){}
97   if( zIn[i]=='-' ){
98     p->sign = 1;
99     i++;
100   }else if( zIn[i]=='+' ){
101     i++;
102   }
103   while( i<n && zIn[i]=='0' ) i++;
104   while( i<n ){
105     char c = zIn[i];
106     if( c>='0' && c<='9' ){
107       p->a[p->nDigit++] = c - '0';
108     }else if( c=='.' ){
109       p->nFrac = p->nDigit + 1;
110     }else if( c=='e' || c=='E' ){
111       int j = i+1;
112       int neg = 0;
113       if( j>=n ) break;
114       if( zIn[j]=='-' ){
115         neg = 1;
116         j++;
117       }else if( zIn[j]=='+' ){
118         j++;
119       }
120       while( j<n && iExp<1000000 ){
121         if( zIn[j]>='0' && zIn[j]<='9' ){
122           iExp = iExp*10 + zIn[j] - '0';
123         }
124         j++;
125       }
126       if( neg ) iExp = -iExp;
127       break;
128     }
129     i++;
130   }
131   if( p->nFrac ){
132     p->nFrac = p->nDigit - (p->nFrac - 1);
133   }
134   if( iExp>0 ){
135     if( p->nFrac>0 ){
136       if( iExp<=p->nFrac ){
137         p->nFrac -= iExp;
138         iExp = 0;
139       }else{
140         iExp -= p->nFrac;
141         p->nFrac = 0;
142       }
143     }
144     if( iExp>0 ){
145       p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
146       if( p->a==0 ) goto new_no_mem;
147       memset(p->a+p->nDigit, 0, iExp);
148       p->nDigit += iExp;
149     }
150   }else if( iExp<0 ){
151     int nExtra;
152     iExp = -iExp;
153     nExtra = p->nDigit - p->nFrac - 1;
154     if( nExtra ){
155       if( nExtra>=iExp ){
156         p->nFrac += iExp;
157         iExp  = 0;
158       }else{
159         iExp -= nExtra;
160         p->nFrac = p->nDigit - 1;
161       }
162     }
163     if( iExp>0 ){
164       p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
165       if( p->a==0 ) goto new_no_mem;
166       memmove(p->a+iExp, p->a, p->nDigit);
167       memset(p->a, 0, iExp);
168       p->nDigit += iExp;
169       p->nFrac += iExp;
170     }
171   }
172   return p;
173 
174 new_no_mem:
175   if( pCtx ) sqlite3_result_error_nomem(pCtx);
176   sqlite3_free(p);
177   return 0;
178 }
179 
180 /*
181 ** Make the given Decimal the result.
182 */
decimal_result(sqlite3_context * pCtx,Decimal * p)183 static void decimal_result(sqlite3_context *pCtx, Decimal *p){
184   char *z;
185   int i, j;
186   int n;
187   if( p==0 || p->oom ){
188     sqlite3_result_error_nomem(pCtx);
189     return;
190   }
191   if( p->isNull ){
192     sqlite3_result_null(pCtx);
193     return;
194   }
195   z = sqlite3_malloc( p->nDigit+4 );
196   if( z==0 ){
197     sqlite3_result_error_nomem(pCtx);
198     return;
199   }
200   i = 0;
201   if( p->nDigit==0 || (p->nDigit==1 && p->a[0]==0) ){
202     p->sign = 0;
203   }
204   if( p->sign ){
205     z[0] = '-';
206     i = 1;
207   }
208   n = p->nDigit - p->nFrac;
209   if( n<=0 ){
210     z[i++] = '0';
211   }
212   j = 0;
213   while( n>1 && p->a[j]==0 ){
214     j++;
215     n--;
216   }
217   while( n>0  ){
218     z[i++] = p->a[j] + '0';
219     j++;
220     n--;
221   }
222   if( p->nFrac ){
223     z[i++] = '.';
224     do{
225       z[i++] = p->a[j] + '0';
226       j++;
227     }while( j<p->nDigit );
228   }
229   z[i] = 0;
230   sqlite3_result_text(pCtx, z, i, sqlite3_free);
231 }
232 
233 /*
234 ** SQL Function:   decimal(X)
235 **
236 ** Convert input X into decimal and then back into text
237 */
decimalFunc(sqlite3_context * context,int argc,sqlite3_value ** argv)238 static void decimalFunc(
239   sqlite3_context *context,
240   int argc,
241   sqlite3_value **argv
242 ){
243   Decimal *p = decimal_new(context, argv[0], 0, 0);
244   UNUSED_PARAMETER(argc);
245   decimal_result(context, p);
246   decimal_free(p);
247 }
248 
249 /*
250 ** Compare to Decimal objects.  Return negative, 0, or positive if the
251 ** first object is less than, equal to, or greater than the second.
252 **
253 ** Preconditions for this routine:
254 **
255 **    pA!=0
256 **    pA->isNull==0
257 **    pB!=0
258 **    pB->isNull==0
259 */
decimal_cmp(const Decimal * pA,const Decimal * pB)260 static int decimal_cmp(const Decimal *pA, const Decimal *pB){
261   int nASig, nBSig, rc, n;
262   if( pA->sign!=pB->sign ){
263     return pA->sign ? -1 : +1;
264   }
265   if( pA->sign ){
266     const Decimal *pTemp = pA;
267     pA = pB;
268     pB = pTemp;
269   }
270   nASig = pA->nDigit - pA->nFrac;
271   nBSig = pB->nDigit - pB->nFrac;
272   if( nASig!=nBSig ){
273     return nASig - nBSig;
274   }
275   n = pA->nDigit;
276   if( n>pB->nDigit ) n = pB->nDigit;
277   rc = memcmp(pA->a, pB->a, n);
278   if( rc==0 ){
279     rc = pA->nDigit - pB->nDigit;
280   }
281   return rc;
282 }
283 
284 /*
285 ** SQL Function:   decimal_cmp(X, Y)
286 **
287 ** Return negative, zero, or positive if X is less then, equal to, or
288 ** greater than Y.
289 */
decimalCmpFunc(sqlite3_context * context,int argc,sqlite3_value ** argv)290 static void decimalCmpFunc(
291   sqlite3_context *context,
292   int argc,
293   sqlite3_value **argv
294 ){
295   Decimal *pA = 0, *pB = 0;
296   int rc;
297 
298   UNUSED_PARAMETER(argc);
299   pA = decimal_new(context, argv[0], 0, 0);
300   if( pA==0 || pA->isNull ) goto cmp_done;
301   pB = decimal_new(context, argv[1], 0, 0);
302   if( pB==0 || pB->isNull ) goto cmp_done;
303   rc = decimal_cmp(pA, pB);
304   if( rc<0 ) rc = -1;
305   else if( rc>0 ) rc = +1;
306   sqlite3_result_int(context, rc);
307 cmp_done:
308   decimal_free(pA);
309   decimal_free(pB);
310 }
311 
312 /*
313 ** Expand the Decimal so that it has a least nDigit digits and nFrac
314 ** digits to the right of the decimal point.
315 */
decimal_expand(Decimal * p,int nDigit,int nFrac)316 static void decimal_expand(Decimal *p, int nDigit, int nFrac){
317   int nAddSig;
318   int nAddFrac;
319   if( p==0 ) return;
320   nAddFrac = nFrac - p->nFrac;
321   nAddSig = (nDigit - p->nDigit) - nAddFrac;
322   if( nAddFrac==0 && nAddSig==0 ) return;
323   p->a = sqlite3_realloc64(p->a, nDigit+1);
324   if( p->a==0 ){
325     p->oom = 1;
326     return;
327   }
328   if( nAddSig ){
329     memmove(p->a+nAddSig, p->a, p->nDigit);
330     memset(p->a, 0, nAddSig);
331     p->nDigit += nAddSig;
332   }
333   if( nAddFrac ){
334     memset(p->a+p->nDigit, 0, nAddFrac);
335     p->nDigit += nAddFrac;
336     p->nFrac += nAddFrac;
337   }
338 }
339 
340 /*
341 ** Add the value pB into pA.
342 **
343 ** Both pA and pB might become denormalized by this routine.
344 */
decimal_add(Decimal * pA,Decimal * pB)345 static void decimal_add(Decimal *pA, Decimal *pB){
346   int nSig, nFrac, nDigit;
347   int i, rc;
348   if( pA==0 ){
349     return;
350   }
351   if( pA->oom || pB==0 || pB->oom ){
352     pA->oom = 1;
353     return;
354   }
355   if( pA->isNull || pB->isNull ){
356     pA->isNull = 1;
357     return;
358   }
359   nSig = pA->nDigit - pA->nFrac;
360   if( nSig && pA->a[0]==0 ) nSig--;
361   if( nSig<pB->nDigit-pB->nFrac ){
362     nSig = pB->nDigit - pB->nFrac;
363   }
364   nFrac = pA->nFrac;
365   if( nFrac<pB->nFrac ) nFrac = pB->nFrac;
366   nDigit = nSig + nFrac + 1;
367   decimal_expand(pA, nDigit, nFrac);
368   decimal_expand(pB, nDigit, nFrac);
369   if( pA->oom || pB->oom ){
370     pA->oom = 1;
371   }else{
372     if( pA->sign==pB->sign ){
373       int carry = 0;
374       for(i=nDigit-1; i>=0; i--){
375         int x = pA->a[i] + pB->a[i] + carry;
376         if( x>=10 ){
377           carry = 1;
378           pA->a[i] = x - 10;
379         }else{
380           carry = 0;
381           pA->a[i] = x;
382         }
383       }
384     }else{
385       signed char *aA, *aB;
386       int borrow = 0;
387       rc = memcmp(pA->a, pB->a, nDigit);
388       if( rc<0 ){
389         aA = pB->a;
390         aB = pA->a;
391         pA->sign = !pA->sign;
392       }else{
393         aA = pA->a;
394         aB = pB->a;
395       }
396       for(i=nDigit-1; i>=0; i--){
397         int x = aA[i] - aB[i] - borrow;
398         if( x<0 ){
399           pA->a[i] = x+10;
400           borrow = 1;
401         }else{
402           pA->a[i] = x;
403           borrow = 0;
404         }
405       }
406     }
407   }
408 }
409 
410 /*
411 ** Compare text in decimal order.
412 */
decimalCollFunc(void * notUsed,int nKey1,const void * pKey1,int nKey2,const void * pKey2)413 static int decimalCollFunc(
414   void *notUsed,
415   int nKey1, const void *pKey1,
416   int nKey2, const void *pKey2
417 ){
418   const unsigned char *zA = (const unsigned char*)pKey1;
419   const unsigned char *zB = (const unsigned char*)pKey2;
420   Decimal *pA = decimal_new(0, 0, nKey1, zA);
421   Decimal *pB = decimal_new(0, 0, nKey2, zB);
422   int rc;
423   UNUSED_PARAMETER(notUsed);
424   if( pA==0 || pB==0 ){
425     rc = 0;
426   }else{
427     rc = decimal_cmp(pA, pB);
428   }
429   decimal_free(pA);
430   decimal_free(pB);
431   return rc;
432 }
433 
434 
435 /*
436 ** SQL Function:   decimal_add(X, Y)
437 **                 decimal_sub(X, Y)
438 **
439 ** Return the sum or difference of X and Y.
440 */
decimalAddFunc(sqlite3_context * context,int argc,sqlite3_value ** argv)441 static void decimalAddFunc(
442   sqlite3_context *context,
443   int argc,
444   sqlite3_value **argv
445 ){
446   Decimal *pA = decimal_new(context, argv[0], 0, 0);
447   Decimal *pB = decimal_new(context, argv[1], 0, 0);
448   UNUSED_PARAMETER(argc);
449   decimal_add(pA, pB);
450   decimal_result(context, pA);
451   decimal_free(pA);
452   decimal_free(pB);
453 }
decimalSubFunc(sqlite3_context * context,int argc,sqlite3_value ** argv)454 static void decimalSubFunc(
455   sqlite3_context *context,
456   int argc,
457   sqlite3_value **argv
458 ){
459   Decimal *pA = decimal_new(context, argv[0], 0, 0);
460   Decimal *pB = decimal_new(context, argv[1], 0, 0);
461   UNUSED_PARAMETER(argc);
462   if( pB ){
463     pB->sign = !pB->sign;
464     decimal_add(pA, pB);
465     decimal_result(context, pA);
466   }
467   decimal_free(pA);
468   decimal_free(pB);
469 }
470 
471 /* Aggregate funcion:   decimal_sum(X)
472 **
473 ** Works like sum() except that it uses decimal arithmetic for unlimited
474 ** precision.
475 */
decimalSumStep(sqlite3_context * context,int argc,sqlite3_value ** argv)476 static void decimalSumStep(
477   sqlite3_context *context,
478   int argc,
479   sqlite3_value **argv
480 ){
481   Decimal *p;
482   Decimal *pArg;
483   UNUSED_PARAMETER(argc);
484   p = sqlite3_aggregate_context(context, sizeof(*p));
485   if( p==0 ) return;
486   if( !p->isInit ){
487     p->isInit = 1;
488     p->a = sqlite3_malloc(2);
489     if( p->a==0 ){
490       p->oom = 1;
491     }else{
492       p->a[0] = 0;
493     }
494     p->nDigit = 1;
495     p->nFrac = 0;
496   }
497   if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
498   pArg = decimal_new(context, argv[0], 0, 0);
499   decimal_add(p, pArg);
500   decimal_free(pArg);
501 }
decimalSumInverse(sqlite3_context * context,int argc,sqlite3_value ** argv)502 static void decimalSumInverse(
503   sqlite3_context *context,
504   int argc,
505   sqlite3_value **argv
506 ){
507   Decimal *p;
508   Decimal *pArg;
509   UNUSED_PARAMETER(argc);
510   p = sqlite3_aggregate_context(context, sizeof(*p));
511   if( p==0 ) return;
512   if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
513   pArg = decimal_new(context, argv[0], 0, 0);
514   if( pArg ) pArg->sign = !pArg->sign;
515   decimal_add(p, pArg);
516   decimal_free(pArg);
517 }
decimalSumValue(sqlite3_context * context)518 static void decimalSumValue(sqlite3_context *context){
519   Decimal *p = sqlite3_aggregate_context(context, 0);
520   if( p==0 ) return;
521   decimal_result(context, p);
522 }
decimalSumFinalize(sqlite3_context * context)523 static void decimalSumFinalize(sqlite3_context *context){
524   Decimal *p = sqlite3_aggregate_context(context, 0);
525   if( p==0 ) return;
526   decimal_result(context, p);
527   decimal_clear(p);
528 }
529 
530 /*
531 ** SQL Function:   decimal_mul(X, Y)
532 **
533 ** Return the product of X and Y.
534 **
535 ** All significant digits after the decimal point are retained.
536 ** Trailing zeros after the decimal point are omitted as long as
537 ** the number of digits after the decimal point is no less than
538 ** either the number of digits in either input.
539 */
decimalMulFunc(sqlite3_context * context,int argc,sqlite3_value ** argv)540 static void decimalMulFunc(
541   sqlite3_context *context,
542   int argc,
543   sqlite3_value **argv
544 ){
545   Decimal *pA = decimal_new(context, argv[0], 0, 0);
546   Decimal *pB = decimal_new(context, argv[1], 0, 0);
547   signed char *acc = 0;
548   int i, j, k;
549   int minFrac;
550   UNUSED_PARAMETER(argc);
551   if( pA==0 || pA->oom || pA->isNull
552    || pB==0 || pB->oom || pB->isNull
553   ){
554     goto mul_end;
555   }
556   acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 );
557   if( acc==0 ){
558     sqlite3_result_error_nomem(context);
559     goto mul_end;
560   }
561   memset(acc, 0, pA->nDigit + pB->nDigit + 2);
562   minFrac = pA->nFrac;
563   if( pB->nFrac<minFrac ) minFrac = pB->nFrac;
564   for(i=pA->nDigit-1; i>=0; i--){
565     signed char f = pA->a[i];
566     int carry = 0, x;
567     for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){
568       x = acc[k] + f*pB->a[j] + carry;
569       acc[k] = x%10;
570       carry = x/10;
571     }
572     x = acc[k] + carry;
573     acc[k] = x%10;
574     acc[k-1] += x/10;
575   }
576   sqlite3_free(pA->a);
577   pA->a = acc;
578   acc = 0;
579   pA->nDigit += pB->nDigit + 2;
580   pA->nFrac += pB->nFrac;
581   pA->sign ^= pB->sign;
582   while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){
583     pA->nFrac--;
584     pA->nDigit--;
585   }
586   decimal_result(context, pA);
587 
588 mul_end:
589   sqlite3_free(acc);
590   decimal_free(pA);
591   decimal_free(pB);
592 }
593 
594 #ifdef _WIN32
595 __declspec(dllexport)
596 #endif
sqlite3_decimal_init(sqlite3 * db,char ** pzErrMsg,const sqlite3_api_routines * pApi)597 int sqlite3_decimal_init(
598   sqlite3 *db,
599   char **pzErrMsg,
600   const sqlite3_api_routines *pApi
601 ){
602   int rc = SQLITE_OK;
603   static const struct {
604     const char *zFuncName;
605     int nArg;
606     void (*xFunc)(sqlite3_context*,int,sqlite3_value**);
607   } aFunc[] = {
608     { "decimal",       1,   decimalFunc        },
609     { "decimal_cmp",   2,   decimalCmpFunc     },
610     { "decimal_add",   2,   decimalAddFunc     },
611     { "decimal_sub",   2,   decimalSubFunc     },
612     { "decimal_mul",   2,   decimalMulFunc     },
613   };
614   unsigned int i;
615   (void)pzErrMsg;  /* Unused parameter */
616 
617   SQLITE_EXTENSION_INIT2(pApi);
618 
619   for(i=0; i<sizeof(aFunc)/sizeof(aFunc[0]) && rc==SQLITE_OK; i++){
620     rc = sqlite3_create_function(db, aFunc[i].zFuncName, aFunc[i].nArg,
621                    SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC,
622                    0, aFunc[i].xFunc, 0, 0);
623   }
624   if( rc==SQLITE_OK ){
625     rc = sqlite3_create_window_function(db, "decimal_sum", 1,
626                    SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 0,
627                    decimalSumStep, decimalSumFinalize,
628                    decimalSumValue, decimalSumInverse, 0);
629   }
630   if( rc==SQLITE_OK ){
631     rc = sqlite3_create_collation(db, "decimal", SQLITE_UTF8,
632                                   0, decimalCollFunc);
633   }
634   return rc;
635 }
636