1 /* 2 ** 2020-06-22 3 ** 4 ** The author disclaims copyright to this source code. In place of 5 ** a legal notice, here is a blessing: 6 ** 7 ** May you do good and not evil. 8 ** May you find forgiveness for yourself and forgive others. 9 ** May you share freely, never taking more than you give. 10 ** 11 ****************************************************************************** 12 ** 13 ** Routines to implement arbitrary-precision decimal math. 14 ** 15 ** The focus here is on simplicity and correctness, not performance. 16 */ 17 #include "sqlite3ext.h" 18 SQLITE_EXTENSION_INIT1 19 #include <assert.h> 20 #include <string.h> 21 #include <ctype.h> 22 #include <stdlib.h> 23 24 /* Mark a function parameter as unused, to suppress nuisance compiler 25 ** warnings. */ 26 #ifndef UNUSED_PARAMETER 27 # define UNUSED_PARAMETER(X) (void)(X) 28 #endif 29 30 31 /* A decimal object */ 32 typedef struct Decimal Decimal; 33 struct Decimal { 34 char sign; /* 0 for positive, 1 for negative */ 35 char oom; /* True if an OOM is encountered */ 36 char isNull; /* True if holds a NULL rather than a number */ 37 char isInit; /* True upon initialization */ 38 int nDigit; /* Total number of digits */ 39 int nFrac; /* Number of digits to the right of the decimal point */ 40 signed char *a; /* Array of digits. Most significant first. */ 41 }; 42 43 /* 44 ** Release memory held by a Decimal, but do not free the object itself. 45 */ 46 static void decimal_clear(Decimal *p){ 47 sqlite3_free(p->a); 48 } 49 50 /* 51 ** Destroy a Decimal object 52 */ 53 static void decimal_free(Decimal *p){ 54 if( p ){ 55 decimal_clear(p); 56 sqlite3_free(p); 57 } 58 } 59 60 /* 61 ** Allocate a new Decimal object. Initialize it to the number given 62 ** by the input string. 63 */ 64 static Decimal *decimal_new( 65 sqlite3_context *pCtx, 66 sqlite3_value *pIn, 67 int nAlt, 68 const unsigned char *zAlt 69 ){ 70 Decimal *p; 71 int n, i; 72 const unsigned char *zIn; 73 int iExp = 0; 74 p = sqlite3_malloc( sizeof(*p) ); 75 if( p==0 ) goto new_no_mem; 76 p->sign = 0; 77 p->oom = 0; 78 p->isInit = 1; 79 p->isNull = 0; 80 p->nDigit = 0; 81 p->nFrac = 0; 82 if( zAlt ){ 83 n = nAlt, 84 zIn = zAlt; 85 }else{ 86 if( sqlite3_value_type(pIn)==SQLITE_NULL ){ 87 p->a = 0; 88 p->isNull = 1; 89 return p; 90 } 91 n = sqlite3_value_bytes(pIn); 92 zIn = sqlite3_value_text(pIn); 93 } 94 p->a = sqlite3_malloc64( n+1 ); 95 if( p->a==0 ) goto new_no_mem; 96 for(i=0; isspace(zIn[i]); i++){} 97 if( zIn[i]=='-' ){ 98 p->sign = 1; 99 i++; 100 }else if( zIn[i]=='+' ){ 101 i++; 102 } 103 while( i<n && zIn[i]=='0' ) i++; 104 while( i<n ){ 105 char c = zIn[i]; 106 if( c>='0' && c<='9' ){ 107 p->a[p->nDigit++] = c - '0'; 108 }else if( c=='.' ){ 109 p->nFrac = p->nDigit + 1; 110 }else if( c=='e' || c=='E' ){ 111 int j = i+1; 112 int neg = 0; 113 if( j>=n ) break; 114 if( zIn[j]=='-' ){ 115 neg = 1; 116 j++; 117 }else if( zIn[j]=='+' ){ 118 j++; 119 } 120 while( j<n && iExp<1000000 ){ 121 if( zIn[j]>='0' && zIn[j]<='9' ){ 122 iExp = iExp*10 + zIn[j] - '0'; 123 } 124 j++; 125 } 126 if( neg ) iExp = -iExp; 127 break; 128 } 129 i++; 130 } 131 if( p->nFrac ){ 132 p->nFrac = p->nDigit - (p->nFrac - 1); 133 } 134 if( iExp>0 ){ 135 if( p->nFrac>0 ){ 136 if( iExp<=p->nFrac ){ 137 p->nFrac -= iExp; 138 iExp = 0; 139 }else{ 140 iExp -= p->nFrac; 141 p->nFrac = 0; 142 } 143 } 144 if( iExp>0 ){ 145 p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 ); 146 if( p->a==0 ) goto new_no_mem; 147 memset(p->a+p->nDigit, 0, iExp); 148 p->nDigit += iExp; 149 } 150 }else if( iExp<0 ){ 151 int nExtra; 152 iExp = -iExp; 153 nExtra = p->nDigit - p->nFrac - 1; 154 if( nExtra ){ 155 if( nExtra>=iExp ){ 156 p->nFrac += iExp; 157 iExp = 0; 158 }else{ 159 iExp -= nExtra; 160 p->nFrac = p->nDigit - 1; 161 } 162 } 163 if( iExp>0 ){ 164 p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 ); 165 if( p->a==0 ) goto new_no_mem; 166 memmove(p->a+iExp, p->a, p->nDigit); 167 memset(p->a, 0, iExp); 168 p->nDigit += iExp; 169 p->nFrac += iExp; 170 } 171 } 172 return p; 173 174 new_no_mem: 175 if( pCtx ) sqlite3_result_error_nomem(pCtx); 176 sqlite3_free(p); 177 return 0; 178 } 179 180 /* 181 ** Make the given Decimal the result. 182 */ 183 static void decimal_result(sqlite3_context *pCtx, Decimal *p){ 184 char *z; 185 int i, j; 186 int n; 187 if( p==0 || p->oom ){ 188 sqlite3_result_error_nomem(pCtx); 189 return; 190 } 191 if( p->isNull ){ 192 sqlite3_result_null(pCtx); 193 return; 194 } 195 z = sqlite3_malloc( p->nDigit+4 ); 196 if( z==0 ){ 197 sqlite3_result_error_nomem(pCtx); 198 return; 199 } 200 i = 0; 201 if( p->nDigit==0 || (p->nDigit==1 && p->a[0]==0) ){ 202 p->sign = 0; 203 } 204 if( p->sign ){ 205 z[0] = '-'; 206 i = 1; 207 } 208 n = p->nDigit - p->nFrac; 209 if( n<=0 ){ 210 z[i++] = '0'; 211 } 212 j = 0; 213 while( n>1 && p->a[j]==0 ){ 214 j++; 215 n--; 216 } 217 while( n>0 ){ 218 z[i++] = p->a[j] + '0'; 219 j++; 220 n--; 221 } 222 if( p->nFrac ){ 223 z[i++] = '.'; 224 do{ 225 z[i++] = p->a[j] + '0'; 226 j++; 227 }while( j<p->nDigit ); 228 } 229 z[i] = 0; 230 sqlite3_result_text(pCtx, z, i, sqlite3_free); 231 } 232 233 /* 234 ** SQL Function: decimal(X) 235 ** 236 ** Convert input X into decimal and then back into text 237 */ 238 static void decimalFunc( 239 sqlite3_context *context, 240 int argc, 241 sqlite3_value **argv 242 ){ 243 Decimal *p = decimal_new(context, argv[0], 0, 0); 244 UNUSED_PARAMETER(argc); 245 decimal_result(context, p); 246 decimal_free(p); 247 } 248 249 /* 250 ** Compare to Decimal objects. Return negative, 0, or positive if the 251 ** first object is less than, equal to, or greater than the second. 252 ** 253 ** Preconditions for this routine: 254 ** 255 ** pA!=0 256 ** pA->isNull==0 257 ** pB!=0 258 ** pB->isNull==0 259 */ 260 static int decimal_cmp(const Decimal *pA, const Decimal *pB){ 261 int nASig, nBSig, rc, n; 262 if( pA->sign!=pB->sign ){ 263 return pA->sign ? -1 : +1; 264 } 265 if( pA->sign ){ 266 const Decimal *pTemp = pA; 267 pA = pB; 268 pB = pTemp; 269 } 270 nASig = pA->nDigit - pA->nFrac; 271 nBSig = pB->nDigit - pB->nFrac; 272 if( nASig!=nBSig ){ 273 return nASig - nBSig; 274 } 275 n = pA->nDigit; 276 if( n>pB->nDigit ) n = pB->nDigit; 277 rc = memcmp(pA->a, pB->a, n); 278 if( rc==0 ){ 279 rc = pA->nDigit - pB->nDigit; 280 } 281 return rc; 282 } 283 284 /* 285 ** SQL Function: decimal_cmp(X, Y) 286 ** 287 ** Return negative, zero, or positive if X is less then, equal to, or 288 ** greater than Y. 289 */ 290 static void decimalCmpFunc( 291 sqlite3_context *context, 292 int argc, 293 sqlite3_value **argv 294 ){ 295 Decimal *pA = 0, *pB = 0; 296 int rc; 297 298 UNUSED_PARAMETER(argc); 299 pA = decimal_new(context, argv[0], 0, 0); 300 if( pA==0 || pA->isNull ) goto cmp_done; 301 pB = decimal_new(context, argv[1], 0, 0); 302 if( pB==0 || pB->isNull ) goto cmp_done; 303 rc = decimal_cmp(pA, pB); 304 if( rc<0 ) rc = -1; 305 else if( rc>0 ) rc = +1; 306 sqlite3_result_int(context, rc); 307 cmp_done: 308 decimal_free(pA); 309 decimal_free(pB); 310 } 311 312 /* 313 ** Expand the Decimal so that it has a least nDigit digits and nFrac 314 ** digits to the right of the decimal point. 315 */ 316 static void decimal_expand(Decimal *p, int nDigit, int nFrac){ 317 int nAddSig; 318 int nAddFrac; 319 if( p==0 ) return; 320 nAddFrac = nFrac - p->nFrac; 321 nAddSig = (nDigit - p->nDigit) - nAddFrac; 322 if( nAddFrac==0 && nAddSig==0 ) return; 323 p->a = sqlite3_realloc64(p->a, nDigit+1); 324 if( p->a==0 ){ 325 p->oom = 1; 326 return; 327 } 328 if( nAddSig ){ 329 memmove(p->a+nAddSig, p->a, p->nDigit); 330 memset(p->a, 0, nAddSig); 331 p->nDigit += nAddSig; 332 } 333 if( nAddFrac ){ 334 memset(p->a+p->nDigit, 0, nAddFrac); 335 p->nDigit += nAddFrac; 336 p->nFrac += nAddFrac; 337 } 338 } 339 340 /* 341 ** Add the value pB into pA. 342 ** 343 ** Both pA and pB might become denormalized by this routine. 344 */ 345 static void decimal_add(Decimal *pA, Decimal *pB){ 346 int nSig, nFrac, nDigit; 347 int i, rc; 348 if( pA==0 ){ 349 return; 350 } 351 if( pA->oom || pB==0 || pB->oom ){ 352 pA->oom = 1; 353 return; 354 } 355 if( pA->isNull || pB->isNull ){ 356 pA->isNull = 1; 357 return; 358 } 359 nSig = pA->nDigit - pA->nFrac; 360 if( nSig && pA->a[0]==0 ) nSig--; 361 if( nSig<pB->nDigit-pB->nFrac ){ 362 nSig = pB->nDigit - pB->nFrac; 363 } 364 nFrac = pA->nFrac; 365 if( nFrac<pB->nFrac ) nFrac = pB->nFrac; 366 nDigit = nSig + nFrac + 1; 367 decimal_expand(pA, nDigit, nFrac); 368 decimal_expand(pB, nDigit, nFrac); 369 if( pA->oom || pB->oom ){ 370 pA->oom = 1; 371 }else{ 372 if( pA->sign==pB->sign ){ 373 int carry = 0; 374 for(i=nDigit-1; i>=0; i--){ 375 int x = pA->a[i] + pB->a[i] + carry; 376 if( x>=10 ){ 377 carry = 1; 378 pA->a[i] = x - 10; 379 }else{ 380 carry = 0; 381 pA->a[i] = x; 382 } 383 } 384 }else{ 385 signed char *aA, *aB; 386 int borrow = 0; 387 rc = memcmp(pA->a, pB->a, nDigit); 388 if( rc<0 ){ 389 aA = pB->a; 390 aB = pA->a; 391 pA->sign = !pA->sign; 392 }else{ 393 aA = pA->a; 394 aB = pB->a; 395 } 396 for(i=nDigit-1; i>=0; i--){ 397 int x = aA[i] - aB[i] - borrow; 398 if( x<0 ){ 399 pA->a[i] = x+10; 400 borrow = 1; 401 }else{ 402 pA->a[i] = x; 403 borrow = 0; 404 } 405 } 406 } 407 } 408 } 409 410 /* 411 ** Compare text in decimal order. 412 */ 413 static int decimalCollFunc( 414 void *notUsed, 415 int nKey1, const void *pKey1, 416 int nKey2, const void *pKey2 417 ){ 418 const unsigned char *zA = (const unsigned char*)pKey1; 419 const unsigned char *zB = (const unsigned char*)pKey2; 420 Decimal *pA = decimal_new(0, 0, nKey1, zA); 421 Decimal *pB = decimal_new(0, 0, nKey2, zB); 422 int rc; 423 UNUSED_PARAMETER(notUsed); 424 if( pA==0 || pB==0 ){ 425 rc = 0; 426 }else{ 427 rc = decimal_cmp(pA, pB); 428 } 429 decimal_free(pA); 430 decimal_free(pB); 431 return rc; 432 } 433 434 435 /* 436 ** SQL Function: decimal_add(X, Y) 437 ** decimal_sub(X, Y) 438 ** 439 ** Return the sum or difference of X and Y. 440 */ 441 static void decimalAddFunc( 442 sqlite3_context *context, 443 int argc, 444 sqlite3_value **argv 445 ){ 446 Decimal *pA = decimal_new(context, argv[0], 0, 0); 447 Decimal *pB = decimal_new(context, argv[1], 0, 0); 448 UNUSED_PARAMETER(argc); 449 decimal_add(pA, pB); 450 decimal_result(context, pA); 451 decimal_free(pA); 452 decimal_free(pB); 453 } 454 static void decimalSubFunc( 455 sqlite3_context *context, 456 int argc, 457 sqlite3_value **argv 458 ){ 459 Decimal *pA = decimal_new(context, argv[0], 0, 0); 460 Decimal *pB = decimal_new(context, argv[1], 0, 0); 461 UNUSED_PARAMETER(argc); 462 if( pB==0 ) return; 463 pB->sign = !pB->sign; 464 decimal_add(pA, pB); 465 decimal_result(context, pA); 466 decimal_free(pA); 467 decimal_free(pB); 468 } 469 470 /* Aggregate funcion: decimal_sum(X) 471 ** 472 ** Works like sum() except that it uses decimal arithmetic for unlimited 473 ** precision. 474 */ 475 static void decimalSumStep( 476 sqlite3_context *context, 477 int argc, 478 sqlite3_value **argv 479 ){ 480 Decimal *p; 481 Decimal *pArg; 482 UNUSED_PARAMETER(argc); 483 p = sqlite3_aggregate_context(context, sizeof(*p)); 484 if( p==0 ) return; 485 if( !p->isInit ){ 486 p->isInit = 1; 487 p->a = sqlite3_malloc(2); 488 if( p->a==0 ){ 489 p->oom = 1; 490 }else{ 491 p->a[0] = 0; 492 } 493 p->nDigit = 1; 494 p->nFrac = 0; 495 } 496 if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return; 497 pArg = decimal_new(context, argv[0], 0, 0); 498 decimal_add(p, pArg); 499 decimal_free(pArg); 500 } 501 static void decimalSumInverse( 502 sqlite3_context *context, 503 int argc, 504 sqlite3_value **argv 505 ){ 506 Decimal *p; 507 Decimal *pArg; 508 UNUSED_PARAMETER(argc); 509 p = sqlite3_aggregate_context(context, sizeof(*p)); 510 if( p==0 ) return; 511 if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return; 512 pArg = decimal_new(context, argv[0], 0, 0); 513 if( pArg ) pArg->sign = !pArg->sign; 514 decimal_add(p, pArg); 515 decimal_free(pArg); 516 } 517 static void decimalSumValue(sqlite3_context *context){ 518 Decimal *p = sqlite3_aggregate_context(context, 0); 519 if( p==0 ) return; 520 decimal_result(context, p); 521 } 522 static void decimalSumFinalize(sqlite3_context *context){ 523 Decimal *p = sqlite3_aggregate_context(context, 0); 524 if( p==0 ) return; 525 decimal_result(context, p); 526 decimal_clear(p); 527 } 528 529 /* 530 ** SQL Function: decimal_mul(X, Y) 531 ** 532 ** Return the product of X and Y. 533 ** 534 ** All significant digits after the decimal point are retained. 535 ** Trailing zeros after the decimal point are omitted as long as 536 ** the number of digits after the decimal point is no less than 537 ** either the number of digits in either input. 538 */ 539 static void decimalMulFunc( 540 sqlite3_context *context, 541 int argc, 542 sqlite3_value **argv 543 ){ 544 Decimal *pA = decimal_new(context, argv[0], 0, 0); 545 Decimal *pB = decimal_new(context, argv[1], 0, 0); 546 signed char *acc = 0; 547 int i, j, k; 548 int minFrac; 549 UNUSED_PARAMETER(argc); 550 if( pA==0 || pA->oom || pA->isNull 551 || pB==0 || pB->oom || pB->isNull 552 ){ 553 goto mul_end; 554 } 555 acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 ); 556 if( acc==0 ){ 557 sqlite3_result_error_nomem(context); 558 goto mul_end; 559 } 560 memset(acc, 0, pA->nDigit + pB->nDigit + 2); 561 minFrac = pA->nFrac; 562 if( pB->nFrac<minFrac ) minFrac = pB->nFrac; 563 for(i=pA->nDigit-1; i>=0; i--){ 564 signed char f = pA->a[i]; 565 int carry = 0, x; 566 for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){ 567 x = acc[k] + f*pB->a[j] + carry; 568 acc[k] = x%10; 569 carry = x/10; 570 } 571 x = acc[k] + carry; 572 acc[k] = x%10; 573 acc[k-1] += x/10; 574 } 575 sqlite3_free(pA->a); 576 pA->a = acc; 577 acc = 0; 578 pA->nDigit += pB->nDigit + 2; 579 pA->nFrac += pB->nFrac; 580 pA->sign ^= pB->sign; 581 while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){ 582 pA->nFrac--; 583 pA->nDigit--; 584 } 585 decimal_result(context, pA); 586 587 mul_end: 588 sqlite3_free(acc); 589 decimal_free(pA); 590 decimal_free(pB); 591 } 592 593 #ifdef _WIN32 594 __declspec(dllexport) 595 #endif 596 int sqlite3_decimal_init( 597 sqlite3 *db, 598 char **pzErrMsg, 599 const sqlite3_api_routines *pApi 600 ){ 601 int rc = SQLITE_OK; 602 static const struct { 603 const char *zFuncName; 604 int nArg; 605 void (*xFunc)(sqlite3_context*,int,sqlite3_value**); 606 } aFunc[] = { 607 { "decimal", 1, decimalFunc }, 608 { "decimal_cmp", 2, decimalCmpFunc }, 609 { "decimal_add", 2, decimalAddFunc }, 610 { "decimal_sub", 2, decimalSubFunc }, 611 { "decimal_mul", 2, decimalMulFunc }, 612 }; 613 unsigned int i; 614 (void)pzErrMsg; /* Unused parameter */ 615 616 SQLITE_EXTENSION_INIT2(pApi); 617 618 for(i=0; i<sizeof(aFunc)/sizeof(aFunc[0]) && rc==SQLITE_OK; i++){ 619 rc = sqlite3_create_function(db, aFunc[i].zFuncName, aFunc[i].nArg, 620 SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 621 0, aFunc[i].xFunc, 0, 0); 622 } 623 if( rc==SQLITE_OK ){ 624 rc = sqlite3_create_window_function(db, "decimal_sum", 1, 625 SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 0, 626 decimalSumStep, decimalSumFinalize, 627 decimalSumValue, decimalSumInverse, 0); 628 } 629 if( rc==SQLITE_OK ){ 630 rc = sqlite3_create_collation(db, "decimal", SQLITE_UTF8, 631 0, decimalCollFunc); 632 } 633 return rc; 634 } 635