@@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
32213221}
32223222
32233223
3224+ /* Number of permutations and combinations.
3225+ * P(n, k) = n! / (n-k)!
3226+ * C(n, k) = P(n, k) / k!
3227+ */
3228+
3229+ /* Calculate C(n, k) for n in the 63-bit range. */
3230+ static PyObject *
3231+ perm_comb_small (unsigned long long n , unsigned long long k , int iscomb )
3232+ {
3233+ /* long long is at least 64 bit */
3234+ static const unsigned long long fast_comb_limits [] = {
3235+ 0 , ULLONG_MAX , 4294967296ULL , 3329022 , 102570 , 13467 , 3612 , 1449 , // 0-7
3236+ 746 , 453 , 308 , 227 , 178 , 147 , 125 , 110 , // 8-15
3237+ 99 , 90 , 84 , 79 , 75 , 72 , 69 , 68 , // 16-23
3238+ 66 , 65 , 64 , 63 , 63 , 62 , 62 , 62 , // 24-31
3239+ };
3240+ static const unsigned long long fast_perm_limits [] = {
3241+ 0 , ULLONG_MAX , 4294967296ULL , 2642246 , 65537 , 7133 , 1627 , 568 , // 0-7
3242+ 259 , 142 , 88 , 61 , 45 , 36 , 30 , // 8-14
3243+ };
3244+
3245+ if (k == 0 ) {
3246+ return PyLong_FromLong (1 );
3247+ }
3248+
3249+ /* For small enough n and k the result fits in the 64-bit range and can
3250+ * be calculated without allocating intermediate PyLong objects. */
3251+ if (iscomb
3252+ ? (k < Py_ARRAY_LENGTH (fast_comb_limits )
3253+ && n <= fast_comb_limits [k ])
3254+ : (k < Py_ARRAY_LENGTH (fast_perm_limits )
3255+ && n <= fast_perm_limits [k ]))
3256+ {
3257+ unsigned long long result = n ;
3258+ if (iscomb ) {
3259+ for (unsigned long long i = 1 ; i < k ;) {
3260+ result *= -- n ;
3261+ result /= ++ i ;
3262+ }
3263+ }
3264+ else {
3265+ for (unsigned long long i = 1 ; i < k ;) {
3266+ result *= -- n ;
3267+ ++ i ;
3268+ }
3269+ }
3270+ return PyLong_FromUnsignedLongLong (result );
3271+ }
3272+
3273+ /* For larger n use recursive formula. */
3274+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
3275+ unsigned long long j = k / 2 ;
3276+ PyObject * a , * b ;
3277+ a = perm_comb_small (n , j , iscomb );
3278+ if (a == NULL ) {
3279+ return NULL ;
3280+ }
3281+ b = perm_comb_small (n - j , k - j , iscomb );
3282+ if (b == NULL ) {
3283+ goto error ;
3284+ }
3285+ Py_SETREF (a , PyNumber_Multiply (a , b ));
3286+ Py_DECREF (b );
3287+ if (iscomb && a != NULL ) {
3288+ b = perm_comb_small (k , j , 1 );
3289+ if (b == NULL ) {
3290+ goto error ;
3291+ }
3292+ Py_SETREF (a , PyNumber_FloorDivide (a , b ));
3293+ Py_DECREF (b );
3294+ }
3295+ return a ;
3296+
3297+ error :
3298+ Py_DECREF (a );
3299+ return NULL ;
3300+ }
3301+
3302+ /* Calculate P(n, k) or C(n, k) using recursive formulas.
3303+ * It is more efficient than sequential multiplication thanks to
3304+ * Karatsuba multiplication.
3305+ */
3306+ static PyObject *
3307+ perm_comb (PyObject * n , unsigned long long k , int iscomb )
3308+ {
3309+ if (k == 0 ) {
3310+ return PyLong_FromLong (1 );
3311+ }
3312+ if (k == 1 ) {
3313+ Py_INCREF (n );
3314+ return n ;
3315+ }
3316+
3317+ /* P(n, k) = P(n, j) * P(n-j, k-j) */
3318+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
3319+ unsigned long long j = k / 2 ;
3320+ PyObject * a , * b ;
3321+ a = perm_comb (n , j , iscomb );
3322+ if (a == NULL ) {
3323+ return NULL ;
3324+ }
3325+ PyObject * t = PyLong_FromUnsignedLongLong (j );
3326+ if (t == NULL ) {
3327+ goto error ;
3328+ }
3329+ n = PyNumber_Subtract (n , t );
3330+ Py_DECREF (t );
3331+ if (n == NULL ) {
3332+ goto error ;
3333+ }
3334+ b = perm_comb (n , k - j , iscomb );
3335+ Py_DECREF (n );
3336+ if (b == NULL ) {
3337+ goto error ;
3338+ }
3339+ Py_SETREF (a , PyNumber_Multiply (a , b ));
3340+ Py_DECREF (b );
3341+ if (iscomb && a != NULL ) {
3342+ b = perm_comb_small (k , j , 1 );
3343+ if (b == NULL ) {
3344+ goto error ;
3345+ }
3346+ Py_SETREF (a , PyNumber_FloorDivide (a , b ));
3347+ Py_DECREF (b );
3348+ }
3349+ return a ;
3350+
3351+ error :
3352+ Py_DECREF (a );
3353+ return NULL ;
3354+ }
3355+
32243356/*[clinic input]
32253357math.perm
32263358
@@ -3244,9 +3376,9 @@ static PyObject *
32443376math_perm_impl (PyObject * module , PyObject * n , PyObject * k )
32453377/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
32463378{
3247- PyObject * result = NULL , * factor = NULL ;
3379+ PyObject * result = NULL ;
32483380 int overflow , cmp ;
3249- long long i , factors ;
3381+ long long ki , ni ;
32503382
32513383 if (k == Py_None ) {
32523384 return math_factorial (module , n );
@@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
32603392 Py_DECREF (n );
32613393 return NULL ;
32623394 }
3395+ assert (PyLong_CheckExact (n ) && PyLong_CheckExact (k ));
32633396
32643397 if (Py_SIZE (n ) < 0 ) {
32653398 PyErr_SetString (PyExc_ValueError ,
@@ -3281,57 +3414,38 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
32813414 goto error ;
32823415 }
32833416
3284- factors = PyLong_AsLongLongAndOverflow (k , & overflow );
3417+ ki = PyLong_AsLongLongAndOverflow (k , & overflow );
3418+ assert (overflow >= 0 && !PyErr_Occurred ());
32853419 if (overflow > 0 ) {
32863420 PyErr_Format (PyExc_OverflowError ,
32873421 "k must not exceed %lld" ,
32883422 LLONG_MAX );
32893423 goto error ;
32903424 }
3291- else if (factors == -1 ) {
3292- /* k is nonnegative, so a return value of -1 can only indicate error */
3293- goto error ;
3294- }
3425+ assert (ki >= 0 );
32953426
3296- if (factors == 0 ) {
3297- result = PyLong_FromLong (1 );
3298- goto done ;
3427+ ni = PyLong_AsLongLongAndOverflow (n , & overflow );
3428+ assert (overflow >= 0 && !PyErr_Occurred ());
3429+ if (!overflow && ki > 1 ) {
3430+ assert (ni >= 0 );
3431+ result = perm_comb_small ((unsigned long long )ni ,
3432+ (unsigned long long )ki , 0 );
32993433 }
3300-
3301- result = n ;
3302- Py_INCREF (result );
3303- if (factors == 1 ) {
3304- goto done ;
3305- }
3306-
3307- factor = Py_NewRef (n );
3308- PyObject * one = _PyLong_GetOne (); // borrowed ref
3309- for (i = 1 ; i < factors ; ++ i ) {
3310- Py_SETREF (factor , PyNumber_Subtract (factor , one ));
3311- if (factor == NULL ) {
3312- goto error ;
3313- }
3314- Py_SETREF (result , PyNumber_Multiply (result , factor ));
3315- if (result == NULL ) {
3316- goto error ;
3317- }
3434+ else {
3435+ result = perm_comb (n , (unsigned long long )ki , 0 );
33183436 }
3319- Py_DECREF (factor );
33203437
33213438done :
33223439 Py_DECREF (n );
33233440 Py_DECREF (k );
33243441 return result ;
33253442
33263443error :
3327- Py_XDECREF (factor );
3328- Py_XDECREF (result );
33293444 Py_DECREF (n );
33303445 Py_DECREF (k );
33313446 return NULL ;
33323447}
33333448
3334-
33353449/*[clinic input]
33363450math.comb
33373451
@@ -3357,9 +3471,9 @@ static PyObject *
33573471math_comb_impl (PyObject * module , PyObject * n , PyObject * k )
33583472/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
33593473{
3360- PyObject * result = NULL , * factor = NULL , * temp ;
3474+ PyObject * result = NULL , * temp ;
33613475 int overflow , cmp ;
3362- long long i , factors ;
3476+ long long ki , ni ;
33633477
33643478 n = PyNumber_Index (n );
33653479 if (n == NULL ) {
@@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
33703484 Py_DECREF (n );
33713485 return NULL ;
33723486 }
3487+ assert (PyLong_CheckExact (n ) && PyLong_CheckExact (k ));
33733488
33743489 if (Py_SIZE (n ) < 0 ) {
33753490 PyErr_SetString (PyExc_ValueError ,
@@ -3382,82 +3497,66 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
33823497 goto error ;
33833498 }
33843499
3385- /* k = min(k, n - k) */
3386- temp = PyNumber_Subtract (n , k );
3387- if (temp == NULL ) {
3388- goto error ;
3389- }
3390- if (Py_SIZE (temp ) < 0 ) {
3391- Py_DECREF (temp );
3392- result = PyLong_FromLong (0 );
3393- goto done ;
3394- }
3395- cmp = PyObject_RichCompareBool (temp , k , Py_LT );
3396- if (cmp > 0 ) {
3397- Py_SETREF (k , temp );
3500+ ni = PyLong_AsLongLongAndOverflow (n , & overflow );
3501+ assert (overflow >= 0 && !PyErr_Occurred ());
3502+ if (!overflow ) {
3503+ assert (ni >= 0 );
3504+ ki = PyLong_AsLongLongAndOverflow (k , & overflow );
3505+ assert (overflow >= 0 && !PyErr_Occurred ());
3506+ if (overflow || ki > ni ) {
3507+ result = PyLong_FromLong (0 );
3508+ goto done ;
3509+ }
3510+ assert (ki >= 0 );
3511+ ki = Py_MIN (ki , ni - ki );
3512+ if (ki > 1 ) {
3513+ result = perm_comb_small ((unsigned long long )ni ,
3514+ (unsigned long long )ki , 1 );
3515+ goto done ;
3516+ }
3517+ /* For k == 1 just return the original n in perm_comb(). */
33983518 }
33993519 else {
3400- Py_DECREF (temp );
3401- if (cmp < 0 ) {
3520+ /* k = min(k, n - k) */
3521+ temp = PyNumber_Subtract (n , k );
3522+ if (temp == NULL ) {
34023523 goto error ;
34033524 }
3404- }
3405-
3406- factors = PyLong_AsLongLongAndOverflow (k , & overflow );
3407- if (overflow > 0 ) {
3408- PyErr_Format (PyExc_OverflowError ,
3409- "min(n - k, k) must not exceed %lld" ,
3410- LLONG_MAX );
3411- goto error ;
3412- }
3413- if (factors == -1 ) {
3414- /* k is nonnegative, so a return value of -1 can only indicate error */
3415- goto error ;
3416- }
3417-
3418- if (factors == 0 ) {
3419- result = PyLong_FromLong (1 );
3420- goto done ;
3421- }
3422-
3423- result = n ;
3424- Py_INCREF (result );
3425- if (factors == 1 ) {
3426- goto done ;
3427- }
3428-
3429- factor = Py_NewRef (n );
3430- PyObject * one = _PyLong_GetOne (); // borrowed ref
3431- for (i = 1 ; i < factors ; ++ i ) {
3432- Py_SETREF (factor , PyNumber_Subtract (factor , one ));
3433- if (factor == NULL ) {
3434- goto error ;
3525+ if (Py_SIZE (temp ) < 0 ) {
3526+ Py_DECREF (temp );
3527+ result = PyLong_FromLong (0 );
3528+ goto done ;
34353529 }
3436- Py_SETREF ( result , PyNumber_Multiply ( result , factor ) );
3437- if (result == NULL ) {
3438- goto error ;
3530+ cmp = PyObject_RichCompareBool ( temp , k , Py_LT );
3531+ if (cmp > 0 ) {
3532+ Py_SETREF ( k , temp ) ;
34393533 }
3440-
3441- temp = PyLong_FromUnsignedLongLong ((unsigned long long )i + 1 );
3442- if (temp == NULL ) {
3443- goto error ;
3534+ else {
3535+ Py_DECREF (temp );
3536+ if (cmp < 0 ) {
3537+ goto error ;
3538+ }
34443539 }
3445- Py_SETREF (result , PyNumber_FloorDivide (result , temp ));
3446- Py_DECREF (temp );
3447- if (result == NULL ) {
3540+
3541+ ki = PyLong_AsLongLongAndOverflow (k , & overflow );
3542+ assert (overflow >= 0 && !PyErr_Occurred ());
3543+ if (overflow ) {
3544+ PyErr_Format (PyExc_OverflowError ,
3545+ "min(n - k, k) must not exceed %lld" ,
3546+ LLONG_MAX );
34483547 goto error ;
34493548 }
3549+ assert (ki >= 0 );
34503550 }
3451- Py_DECREF (factor );
3551+
3552+ result = perm_comb (n , (unsigned long long )ki , 1 );
34523553
34533554done :
34543555 Py_DECREF (n );
34553556 Py_DECREF (k );
34563557 return result ;
34573558
34583559error :
3459- Py_XDECREF (factor );
3460- Py_XDECREF (result );
34613560 Py_DECREF (n );
34623561 Py_DECREF (k );
34633562 return NULL ;
0 commit comments