xref: /sqlite-3.40.0/ext/misc/decimal.c (revision aeb4e6ee)
1 /*
2 ** 2020-06-22
3 **
4 ** The author disclaims copyright to this source code.  In place of
5 ** a legal notice, here is a blessing:
6 **
7 **    May you do good and not evil.
8 **    May you find forgiveness for yourself and forgive others.
9 **    May you share freely, never taking more than you give.
10 **
11 ******************************************************************************
12 **
13 ** Routines to implement arbitrary-precision decimal math.
14 **
15 ** The focus here is on simplicity and correctness, not performance.
16 */
17 #include "sqlite3ext.h"
18 SQLITE_EXTENSION_INIT1
19 #include <assert.h>
20 #include <string.h>
21 #include <ctype.h>
22 #include <stdlib.h>
23 
24 /* Mark a function parameter as unused, to suppress nuisance compiler
25 ** warnings. */
26 #ifndef UNUSED_PARAMETER
27 # define UNUSED_PARAMETER(X)  (void)(X)
28 #endif
29 
30 
31 /* A decimal object */
32 typedef struct Decimal Decimal;
33 struct Decimal {
34   char sign;        /* 0 for positive, 1 for negative */
35   char oom;         /* True if an OOM is encountered */
36   char isNull;      /* True if holds a NULL rather than a number */
37   char isInit;      /* True upon initialization */
38   int nDigit;       /* Total number of digits */
39   int nFrac;        /* Number of digits to the right of the decimal point */
40   signed char *a;   /* Array of digits.  Most significant first. */
41 };
42 
43 /*
44 ** Release memory held by a Decimal, but do not free the object itself.
45 */
46 static void decimal_clear(Decimal *p){
47   sqlite3_free(p->a);
48 }
49 
50 /*
51 ** Destroy a Decimal object
52 */
53 static void decimal_free(Decimal *p){
54   if( p ){
55     decimal_clear(p);
56     sqlite3_free(p);
57   }
58 }
59 
60 /*
61 ** Allocate a new Decimal object.  Initialize it to the number given
62 ** by the input string.
63 */
64 static Decimal *decimal_new(
65   sqlite3_context *pCtx,
66   sqlite3_value *pIn,
67   int nAlt,
68   const unsigned char *zAlt
69 ){
70   Decimal *p;
71   int n, i;
72   const unsigned char *zIn;
73   int iExp = 0;
74   p = sqlite3_malloc( sizeof(*p) );
75   if( p==0 ) goto new_no_mem;
76   p->sign = 0;
77   p->oom = 0;
78   p->isInit = 1;
79   p->isNull = 0;
80   p->nDigit = 0;
81   p->nFrac = 0;
82   if( zAlt ){
83     n = nAlt,
84     zIn = zAlt;
85   }else{
86     if( sqlite3_value_type(pIn)==SQLITE_NULL ){
87       p->a = 0;
88       p->isNull = 1;
89       return p;
90     }
91     n = sqlite3_value_bytes(pIn);
92     zIn = sqlite3_value_text(pIn);
93   }
94   p->a = sqlite3_malloc64( n+1 );
95   if( p->a==0 ) goto new_no_mem;
96   for(i=0; isspace(zIn[i]); i++){}
97   if( zIn[i]=='-' ){
98     p->sign = 1;
99     i++;
100   }else if( zIn[i]=='+' ){
101     i++;
102   }
103   while( i<n && zIn[i]=='0' ) i++;
104   while( i<n ){
105     char c = zIn[i];
106     if( c>='0' && c<='9' ){
107       p->a[p->nDigit++] = c - '0';
108     }else if( c=='.' ){
109       p->nFrac = p->nDigit + 1;
110     }else if( c=='e' || c=='E' ){
111       int j = i+1;
112       int neg = 0;
113       if( j>=n ) break;
114       if( zIn[j]=='-' ){
115         neg = 1;
116         j++;
117       }else if( zIn[j]=='+' ){
118         j++;
119       }
120       while( j<n && iExp<1000000 ){
121         if( zIn[j]>='0' && zIn[j]<='9' ){
122           iExp = iExp*10 + zIn[j] - '0';
123         }
124         j++;
125       }
126       if( neg ) iExp = -iExp;
127       break;
128     }
129     i++;
130   }
131   if( p->nFrac ){
132     p->nFrac = p->nDigit - (p->nFrac - 1);
133   }
134   if( iExp>0 ){
135     if( p->nFrac>0 ){
136       if( iExp<=p->nFrac ){
137         p->nFrac -= iExp;
138         iExp = 0;
139       }else{
140         iExp -= p->nFrac;
141         p->nFrac = 0;
142       }
143     }
144     if( iExp>0 ){
145       p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
146       if( p->a==0 ) goto new_no_mem;
147       memset(p->a+p->nDigit, 0, iExp);
148       p->nDigit += iExp;
149     }
150   }else if( iExp<0 ){
151     int nExtra;
152     iExp = -iExp;
153     nExtra = p->nDigit - p->nFrac - 1;
154     if( nExtra ){
155       if( nExtra>=iExp ){
156         p->nFrac += iExp;
157         iExp  = 0;
158       }else{
159         iExp -= nExtra;
160         p->nFrac = p->nDigit - 1;
161       }
162     }
163     if( iExp>0 ){
164       p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 );
165       if( p->a==0 ) goto new_no_mem;
166       memmove(p->a+iExp, p->a, p->nDigit);
167       memset(p->a, 0, iExp);
168       p->nDigit += iExp;
169       p->nFrac += iExp;
170     }
171   }
172   return p;
173 
174 new_no_mem:
175   if( pCtx ) sqlite3_result_error_nomem(pCtx);
176   sqlite3_free(p);
177   return 0;
178 }
179 
180 /*
181 ** Make the given Decimal the result.
182 */
183 static void decimal_result(sqlite3_context *pCtx, Decimal *p){
184   char *z;
185   int i, j;
186   int n;
187   if( p==0 || p->oom ){
188     sqlite3_result_error_nomem(pCtx);
189     return;
190   }
191   if( p->isNull ){
192     sqlite3_result_null(pCtx);
193     return;
194   }
195   z = sqlite3_malloc( p->nDigit+4 );
196   if( z==0 ){
197     sqlite3_result_error_nomem(pCtx);
198     return;
199   }
200   i = 0;
201   if( p->nDigit==0 || (p->nDigit==1 && p->a[0]==0) ){
202     p->sign = 0;
203   }
204   if( p->sign ){
205     z[0] = '-';
206     i = 1;
207   }
208   n = p->nDigit - p->nFrac;
209   if( n<=0 ){
210     z[i++] = '0';
211   }
212   j = 0;
213   while( n>1 && p->a[j]==0 ){
214     j++;
215     n--;
216   }
217   while( n>0  ){
218     z[i++] = p->a[j] + '0';
219     j++;
220     n--;
221   }
222   if( p->nFrac ){
223     z[i++] = '.';
224     do{
225       z[i++] = p->a[j] + '0';
226       j++;
227     }while( j<p->nDigit );
228   }
229   z[i] = 0;
230   sqlite3_result_text(pCtx, z, i, sqlite3_free);
231 }
232 
233 /*
234 ** SQL Function:   decimal(X)
235 **
236 ** Convert input X into decimal and then back into text
237 */
238 static void decimalFunc(
239   sqlite3_context *context,
240   int argc,
241   sqlite3_value **argv
242 ){
243   Decimal *p = decimal_new(context, argv[0], 0, 0);
244   UNUSED_PARAMETER(argc);
245   decimal_result(context, p);
246   decimal_free(p);
247 }
248 
249 /*
250 ** Compare to Decimal objects.  Return negative, 0, or positive if the
251 ** first object is less than, equal to, or greater than the second.
252 **
253 ** Preconditions for this routine:
254 **
255 **    pA!=0
256 **    pA->isNull==0
257 **    pB!=0
258 **    pB->isNull==0
259 */
260 static int decimal_cmp(const Decimal *pA, const Decimal *pB){
261   int nASig, nBSig, rc, n;
262   if( pA->sign!=pB->sign ){
263     return pA->sign ? -1 : +1;
264   }
265   if( pA->sign ){
266     const Decimal *pTemp = pA;
267     pA = pB;
268     pB = pTemp;
269   }
270   nASig = pA->nDigit - pA->nFrac;
271   nBSig = pB->nDigit - pB->nFrac;
272   if( nASig!=nBSig ){
273     return nASig - nBSig;
274   }
275   n = pA->nDigit;
276   if( n>pB->nDigit ) n = pB->nDigit;
277   rc = memcmp(pA->a, pB->a, n);
278   if( rc==0 ){
279     rc = pA->nDigit - pB->nDigit;
280   }
281   return rc;
282 }
283 
284 /*
285 ** SQL Function:   decimal_cmp(X, Y)
286 **
287 ** Return negative, zero, or positive if X is less then, equal to, or
288 ** greater than Y.
289 */
290 static void decimalCmpFunc(
291   sqlite3_context *context,
292   int argc,
293   sqlite3_value **argv
294 ){
295   Decimal *pA = 0, *pB = 0;
296   int rc;
297 
298   UNUSED_PARAMETER(argc);
299   pA = decimal_new(context, argv[0], 0, 0);
300   if( pA==0 || pA->isNull ) goto cmp_done;
301   pB = decimal_new(context, argv[1], 0, 0);
302   if( pB==0 || pB->isNull ) goto cmp_done;
303   rc = decimal_cmp(pA, pB);
304   if( rc<0 ) rc = -1;
305   else if( rc>0 ) rc = +1;
306   sqlite3_result_int(context, rc);
307 cmp_done:
308   decimal_free(pA);
309   decimal_free(pB);
310 }
311 
312 /*
313 ** Expand the Decimal so that it has a least nDigit digits and nFrac
314 ** digits to the right of the decimal point.
315 */
316 static void decimal_expand(Decimal *p, int nDigit, int nFrac){
317   int nAddSig;
318   int nAddFrac;
319   if( p==0 ) return;
320   nAddFrac = nFrac - p->nFrac;
321   nAddSig = (nDigit - p->nDigit) - nAddFrac;
322   if( nAddFrac==0 && nAddSig==0 ) return;
323   p->a = sqlite3_realloc64(p->a, nDigit+1);
324   if( p->a==0 ){
325     p->oom = 1;
326     return;
327   }
328   if( nAddSig ){
329     memmove(p->a+nAddSig, p->a, p->nDigit);
330     memset(p->a, 0, nAddSig);
331     p->nDigit += nAddSig;
332   }
333   if( nAddFrac ){
334     memset(p->a+p->nDigit, 0, nAddFrac);
335     p->nDigit += nAddFrac;
336     p->nFrac += nAddFrac;
337   }
338 }
339 
340 /*
341 ** Add the value pB into pA.
342 **
343 ** Both pA and pB might become denormalized by this routine.
344 */
345 static void decimal_add(Decimal *pA, Decimal *pB){
346   int nSig, nFrac, nDigit;
347   int i, rc;
348   if( pA==0 ){
349     return;
350   }
351   if( pA->oom || pB==0 || pB->oom ){
352     pA->oom = 1;
353     return;
354   }
355   if( pA->isNull || pB->isNull ){
356     pA->isNull = 1;
357     return;
358   }
359   nSig = pA->nDigit - pA->nFrac;
360   if( nSig && pA->a[0]==0 ) nSig--;
361   if( nSig<pB->nDigit-pB->nFrac ){
362     nSig = pB->nDigit - pB->nFrac;
363   }
364   nFrac = pA->nFrac;
365   if( nFrac<pB->nFrac ) nFrac = pB->nFrac;
366   nDigit = nSig + nFrac + 1;
367   decimal_expand(pA, nDigit, nFrac);
368   decimal_expand(pB, nDigit, nFrac);
369   if( pA->oom || pB->oom ){
370     pA->oom = 1;
371   }else{
372     if( pA->sign==pB->sign ){
373       int carry = 0;
374       for(i=nDigit-1; i>=0; i--){
375         int x = pA->a[i] + pB->a[i] + carry;
376         if( x>=10 ){
377           carry = 1;
378           pA->a[i] = x - 10;
379         }else{
380           carry = 0;
381           pA->a[i] = x;
382         }
383       }
384     }else{
385       signed char *aA, *aB;
386       int borrow = 0;
387       rc = memcmp(pA->a, pB->a, nDigit);
388       if( rc<0 ){
389         aA = pB->a;
390         aB = pA->a;
391         pA->sign = !pA->sign;
392       }else{
393         aA = pA->a;
394         aB = pB->a;
395       }
396       for(i=nDigit-1; i>=0; i--){
397         int x = aA[i] - aB[i] - borrow;
398         if( x<0 ){
399           pA->a[i] = x+10;
400           borrow = 1;
401         }else{
402           pA->a[i] = x;
403           borrow = 0;
404         }
405       }
406     }
407   }
408 }
409 
410 /*
411 ** Compare text in decimal order.
412 */
413 static int decimalCollFunc(
414   void *notUsed,
415   int nKey1, const void *pKey1,
416   int nKey2, const void *pKey2
417 ){
418   const unsigned char *zA = (const unsigned char*)pKey1;
419   const unsigned char *zB = (const unsigned char*)pKey2;
420   Decimal *pA = decimal_new(0, 0, nKey1, zA);
421   Decimal *pB = decimal_new(0, 0, nKey2, zB);
422   int rc;
423   UNUSED_PARAMETER(notUsed);
424   if( pA==0 || pB==0 ){
425     rc = 0;
426   }else{
427     rc = decimal_cmp(pA, pB);
428   }
429   decimal_free(pA);
430   decimal_free(pB);
431   return rc;
432 }
433 
434 
435 /*
436 ** SQL Function:   decimal_add(X, Y)
437 **                 decimal_sub(X, Y)
438 **
439 ** Return the sum or difference of X and Y.
440 */
441 static void decimalAddFunc(
442   sqlite3_context *context,
443   int argc,
444   sqlite3_value **argv
445 ){
446   Decimal *pA = decimal_new(context, argv[0], 0, 0);
447   Decimal *pB = decimal_new(context, argv[1], 0, 0);
448   UNUSED_PARAMETER(argc);
449   decimal_add(pA, pB);
450   decimal_result(context, pA);
451   decimal_free(pA);
452   decimal_free(pB);
453 }
454 static void decimalSubFunc(
455   sqlite3_context *context,
456   int argc,
457   sqlite3_value **argv
458 ){
459   Decimal *pA = decimal_new(context, argv[0], 0, 0);
460   Decimal *pB = decimal_new(context, argv[1], 0, 0);
461   UNUSED_PARAMETER(argc);
462   if( pB==0 ) return;
463   pB->sign = !pB->sign;
464   decimal_add(pA, pB);
465   decimal_result(context, pA);
466   decimal_free(pA);
467   decimal_free(pB);
468 }
469 
470 /* Aggregate funcion:   decimal_sum(X)
471 **
472 ** Works like sum() except that it uses decimal arithmetic for unlimited
473 ** precision.
474 */
475 static void decimalSumStep(
476   sqlite3_context *context,
477   int argc,
478   sqlite3_value **argv
479 ){
480   Decimal *p;
481   Decimal *pArg;
482   UNUSED_PARAMETER(argc);
483   p = sqlite3_aggregate_context(context, sizeof(*p));
484   if( p==0 ) return;
485   if( !p->isInit ){
486     p->isInit = 1;
487     p->a = sqlite3_malloc(2);
488     if( p->a==0 ){
489       p->oom = 1;
490     }else{
491       p->a[0] = 0;
492     }
493     p->nDigit = 1;
494     p->nFrac = 0;
495   }
496   if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
497   pArg = decimal_new(context, argv[0], 0, 0);
498   decimal_add(p, pArg);
499   decimal_free(pArg);
500 }
501 static void decimalSumInverse(
502   sqlite3_context *context,
503   int argc,
504   sqlite3_value **argv
505 ){
506   Decimal *p;
507   Decimal *pArg;
508   UNUSED_PARAMETER(argc);
509   p = sqlite3_aggregate_context(context, sizeof(*p));
510   if( p==0 ) return;
511   if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return;
512   pArg = decimal_new(context, argv[0], 0, 0);
513   if( pArg ) pArg->sign = !pArg->sign;
514   decimal_add(p, pArg);
515   decimal_free(pArg);
516 }
517 static void decimalSumValue(sqlite3_context *context){
518   Decimal *p = sqlite3_aggregate_context(context, 0);
519   if( p==0 ) return;
520   decimal_result(context, p);
521 }
522 static void decimalSumFinalize(sqlite3_context *context){
523   Decimal *p = sqlite3_aggregate_context(context, 0);
524   if( p==0 ) return;
525   decimal_result(context, p);
526   decimal_clear(p);
527 }
528 
529 /*
530 ** SQL Function:   decimal_mul(X, Y)
531 **
532 ** Return the product of X and Y.
533 **
534 ** All significant digits after the decimal point are retained.
535 ** Trailing zeros after the decimal point are omitted as long as
536 ** the number of digits after the decimal point is no less than
537 ** either the number of digits in either input.
538 */
539 static void decimalMulFunc(
540   sqlite3_context *context,
541   int argc,
542   sqlite3_value **argv
543 ){
544   Decimal *pA = decimal_new(context, argv[0], 0, 0);
545   Decimal *pB = decimal_new(context, argv[1], 0, 0);
546   signed char *acc = 0;
547   int i, j, k;
548   int minFrac;
549   UNUSED_PARAMETER(argc);
550   if( pA==0 || pA->oom || pA->isNull
551    || pB==0 || pB->oom || pB->isNull
552   ){
553     goto mul_end;
554   }
555   acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 );
556   if( acc==0 ){
557     sqlite3_result_error_nomem(context);
558     goto mul_end;
559   }
560   memset(acc, 0, pA->nDigit + pB->nDigit + 2);
561   minFrac = pA->nFrac;
562   if( pB->nFrac<minFrac ) minFrac = pB->nFrac;
563   for(i=pA->nDigit-1; i>=0; i--){
564     signed char f = pA->a[i];
565     int carry = 0, x;
566     for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){
567       x = acc[k] + f*pB->a[j] + carry;
568       acc[k] = x%10;
569       carry = x/10;
570     }
571     x = acc[k] + carry;
572     acc[k] = x%10;
573     acc[k-1] += x/10;
574   }
575   sqlite3_free(pA->a);
576   pA->a = acc;
577   acc = 0;
578   pA->nDigit += pB->nDigit + 2;
579   pA->nFrac += pB->nFrac;
580   pA->sign ^= pB->sign;
581   while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){
582     pA->nFrac--;
583     pA->nDigit--;
584   }
585   decimal_result(context, pA);
586 
587 mul_end:
588   sqlite3_free(acc);
589   decimal_free(pA);
590   decimal_free(pB);
591 }
592 
593 #ifdef _WIN32
594 __declspec(dllexport)
595 #endif
596 int sqlite3_decimal_init(
597   sqlite3 *db,
598   char **pzErrMsg,
599   const sqlite3_api_routines *pApi
600 ){
601   int rc = SQLITE_OK;
602   static const struct {
603     const char *zFuncName;
604     int nArg;
605     void (*xFunc)(sqlite3_context*,int,sqlite3_value**);
606   } aFunc[] = {
607     { "decimal",       1,   decimalFunc        },
608     { "decimal_cmp",   2,   decimalCmpFunc     },
609     { "decimal_add",   2,   decimalAddFunc     },
610     { "decimal_sub",   2,   decimalSubFunc     },
611     { "decimal_mul",   2,   decimalMulFunc     },
612   };
613   unsigned int i;
614   (void)pzErrMsg;  /* Unused parameter */
615 
616   SQLITE_EXTENSION_INIT2(pApi);
617 
618   for(i=0; i<sizeof(aFunc)/sizeof(aFunc[0]) && rc==SQLITE_OK; i++){
619     rc = sqlite3_create_function(db, aFunc[i].zFuncName, aFunc[i].nArg,
620                    SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC,
621                    0, aFunc[i].xFunc, 0, 0);
622   }
623   if( rc==SQLITE_OK ){
624     rc = sqlite3_create_window_function(db, "decimal_sum", 1,
625                    SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 0,
626                    decimalSumStep, decimalSumFinalize,
627                    decimalSumValue, decimalSumInverse, 0);
628   }
629   if( rc==SQLITE_OK ){
630     rc = sqlite3_create_collation(db, "decimal", SQLITE_UTF8,
631                                   0, decimalCollFunc);
632   }
633   return rc;
634 }
635