@@ -240,10 +240,6 @@ w_short_pstring(const void *s, Py_ssize_t n, WFILE *p)
240240#define PyLong_MARSHAL_SHIFT 15
241241#define PyLong_MARSHAL_BASE ((short)1 << PyLong_MARSHAL_SHIFT)
242242#define PyLong_MARSHAL_MASK (PyLong_MARSHAL_BASE - 1)
243- #if PyLong_SHIFT % PyLong_MARSHAL_SHIFT != 0
244- #error "PyLong_SHIFT must be a multiple of PyLong_MARSHAL_SHIFT"
245- #endif
246- #define PyLong_MARSHAL_RATIO (PyLong_SHIFT / PyLong_MARSHAL_SHIFT)
247243
248244#define W_TYPE (t , p ) do { \
249245 w_byte((t) | flag, (p)); \
@@ -252,47 +248,106 @@ w_short_pstring(const void *s, Py_ssize_t n, WFILE *p)
252248static PyObject *
253249_PyMarshal_WriteObjectToString (PyObject * x , int version , int allow_code );
254250
251+ #define _r_digits (bitsize ) \
252+ static void \
253+ _r_digits##bitsize(const uint ## bitsize ## _t *digits, Py_ssize_t n, \
254+ uint8_t negative, Py_ssize_t marshal_ratio, WFILE *p) \
255+ { \
256+ /* set l to number of base PyLong_MARSHAL_BASE digits */ \
257+ Py_ssize_t l = (n - 1 )* marshal_ratio ; \
258+ uint ## bitsize ## _t d = digits[n - 1]; \
259+ \
260+ assert(marshal_ratio > 0); \
261+ assert(n >= 1); \
262+ assert(d != 0); /* a PyLong is always normalized */ \
263+ do { \
264+ d >>= PyLong_MARSHAL_SHIFT ; \
265+ l ++ ; \
266+ } while (d != 0 ); \
267+ if (l > SIZE32_MAX ) { \
268+ p -> depth -- ; \
269+ p -> error = WFERR_UNMARSHALLABLE ; \
270+ return ; \
271+ } \
272+ w_long ((long )(negative ? - l : l ), p ); \
273+ \
274+ for (Py_ssize_t i = 0 ; i < n - 1 ; i ++ ) { \
275+ d = digits [i ]; \
276+ for (Py_ssize_t j = 0 ; j < marshal_ratio ; j ++ ) { \
277+ w_short (d & PyLong_MARSHAL_MASK , p ); \
278+ d >>= PyLong_MARSHAL_SHIFT ; \
279+ } \
280+ assert (d == 0 ); \
281+ } \
282+ d = digits [n - 1 ]; \
283+ do { \
284+ w_short (d & PyLong_MARSHAL_MASK , p ); \
285+ d >>= PyLong_MARSHAL_SHIFT ; \
286+ } while (d != 0 ); \
287+ }
288+ _r_digits (16 )
289+ _r_digits (32 )
290+ #undef _r_digits
291+
255292static void
256293w_PyLong (const PyLongObject * ob , char flag , WFILE * p )
257294{
258- Py_ssize_t i , j , n , l ;
259- digit d ;
260-
261295 W_TYPE (TYPE_LONG , p );
262296 if (_PyLong_IsZero (ob )) {
263297 w_long ((long )0 , p );
264298 return ;
265299 }
266300
267- /* set l to number of base PyLong_MARSHAL_BASE digits */
268- n = _PyLong_DigitCount (ob );
269- l = (n - 1 ) * PyLong_MARSHAL_RATIO ;
270- d = ob -> long_value .ob_digit [n - 1 ];
271- assert (d != 0 ); /* a PyLong is always normalized */
272- do {
273- d >>= PyLong_MARSHAL_SHIFT ;
274- l ++ ;
275- } while (d != 0 );
276- if (l > SIZE32_MAX ) {
301+ PyLongExport long_export ;
302+
303+ if (PyLong_Export ((PyObject * )ob , & long_export ) < 0 ) {
277304 p -> depth -- ;
278305 p -> error = WFERR_UNMARSHALLABLE ;
279306 return ;
280307 }
281- w_long ((long )(_PyLong_IsNegative (ob ) ? - l : l ), p );
308+ if (!long_export .digits ) {
309+ int8_t sign = long_export .value < 0 ? -1 : 1 ;
310+ uint64_t abs_value = Py_ABS (long_export .value );
311+ uint64_t d = abs_value ;
312+ long l = 0 ;
282313
283- for (i = 0 ; i < n - 1 ; i ++ ) {
284- d = ob -> long_value .ob_digit [i ];
285- for (j = 0 ; j < PyLong_MARSHAL_RATIO ; j ++ ) {
314+ /* set l to number of base PyLong_MARSHAL_BASE digits */
315+ do {
316+ d >>= PyLong_MARSHAL_SHIFT ;
317+ l += sign ;
318+ } while (d );
319+ w_long (l , p );
320+
321+ d = abs_value ;
322+ do {
286323 w_short (d & PyLong_MARSHAL_MASK , p );
287324 d >>= PyLong_MARSHAL_SHIFT ;
288- }
289- assert ( d == 0 ) ;
325+ } while ( d );
326+ return ;
290327 }
291- d = ob -> long_value .ob_digit [n - 1 ];
292- do {
293- w_short (d & PyLong_MARSHAL_MASK , p );
294- d >>= PyLong_MARSHAL_SHIFT ;
295- } while (d != 0 );
328+
329+ const PyLongLayout * layout = PyLong_GetNativeLayout ();
330+ Py_ssize_t marshal_ratio = layout -> bits_per_digit /PyLong_MARSHAL_SHIFT ;
331+
332+ /* must be a multiple of PyLong_MARSHAL_SHIFT */
333+ assert (layout -> bits_per_digit % PyLong_MARSHAL_SHIFT == 0 );
334+ assert (layout -> bits_per_digit >= PyLong_MARSHAL_SHIFT );
335+
336+ /* other assumptions on PyLongObject internals */
337+ assert (layout -> bits_per_digit <= 32 );
338+ assert (layout -> digits_order == -1 );
339+ assert (layout -> digit_endianness == (PY_LITTLE_ENDIAN ? -1 : 1 ));
340+ assert (layout -> digit_size == 2 || layout -> digit_size == 4 );
341+
342+ if (layout -> digit_size == 4 ) {
343+ _r_digits32 (long_export .digits , long_export .ndigits ,
344+ long_export .negative , marshal_ratio , p );
345+ }
346+ else {
347+ _r_digits16 (long_export .digits , long_export .ndigits ,
348+ long_export .negative , marshal_ratio , p );
349+ }
350+ PyLong_FreeExport (& long_export );
296351}
297352
298353static void
@@ -875,17 +930,62 @@ r_long64(RFILE *p)
875930 1 /* signed */ );
876931}
877932
933+ #define _w_digits (bitsize ) \
934+ static int \
935+ _w_digits##bitsize(uint ## bitsize ## _t *digits, Py_ssize_t size, \
936+ Py_ssize_t marshal_ratio, \
937+ int shorts_in_top_digit, RFILE *p) \
938+ { \
939+ uint ## bitsize ## _t d; \
940+ \
941+ assert(size >= 1); \
942+ for (Py_ssize_t i = 0; i < size - 1; i++) { \
943+ d = 0; \
944+ for (Py_ssize_t j = 0; j < marshal_ratio; j++) { \
945+ int md = r_short(p); \
946+ if (md < 0 || md > PyLong_MARSHAL_BASE) { \
947+ goto bad_digit; \
948+ } \
949+ d += (uint ## bitsize ## _t)md << j*PyLong_MARSHAL_SHIFT; \
950+ } \
951+ digits[i] = d; \
952+ } \
953+ \
954+ d = 0; \
955+ for (Py_ssize_t j = 0; j < shorts_in_top_digit; j++) { \
956+ int md = r_short(p); \
957+ if (md < 0 || md > PyLong_MARSHAL_BASE) { \
958+ goto bad_digit; \
959+ } \
960+ /* topmost marshal digit should be nonzero */ \
961+ if (md == 0 && j == shorts_in_top_digit - 1 ) { \
962+ PyErr_SetString (PyExc_ValueError , \
963+ "bad marshal data (unnormalized long data)" ); \
964+ return -1 ; \
965+ } \
966+ d += (uint ## bitsize ## _t)md << j*PyLong_MARSHAL_SHIFT; \
967+ } \
968+ assert(!PyErr_Occurred()); \
969+ /* top digit should be nonzero, else the resulting PyLong won't be \
970+ normalized */ \
971+ digits [size - 1 ] = d ; \
972+ return 0 ; \
973+ \
974+ bad_digit : \
975+ if (!PyErr_Occurred ()) { \
976+ PyErr_SetString (PyExc_ValueError , \
977+ "bad marshal data (digit out of range in long)" ); \
978+ } \
979+ return -1 ; \
980+ }
981+ _w_digits (32 )
982+ _w_digits (16 )
983+ #undef _w_digits
984+
878985static PyObject *
879986r_PyLong (RFILE * p )
880987{
881- PyLongObject * ob ;
882- long n , size , i ;
883- int j , md , shorts_in_top_digit ;
884- digit d ;
885-
886- n = r_long (p );
887- if (n == 0 )
888- return (PyObject * )_PyLong_New (0 );
988+ long n = r_long (p );
889989 if (n == -1 && PyErr_Occurred ()) {
890990 return NULL ;
891991 }
@@ -895,51 +995,44 @@ r_PyLong(RFILE *p)
895995 return NULL ;
896996 }
897997
898- size = 1 + (Py_ABS (n ) - 1 ) / PyLong_MARSHAL_RATIO ;
899- shorts_in_top_digit = 1 + (Py_ABS (n ) - 1 ) % PyLong_MARSHAL_RATIO ;
900- ob = _PyLong_New (size );
901- if (ob == NULL )
902- return NULL ;
998+ const PyLongLayout * layout = PyLong_GetNativeLayout ();
999+ Py_ssize_t marshal_ratio = layout -> bits_per_digit /PyLong_MARSHAL_SHIFT ;
9031000
904- _PyLong_SetSignAndDigitCount (ob , n < 0 ? -1 : 1 , size );
1001+ /* must be a multiple of PyLong_MARSHAL_SHIFT */
1002+ assert (layout -> bits_per_digit % PyLong_MARSHAL_SHIFT == 0 );
1003+ assert (layout -> bits_per_digit >= PyLong_MARSHAL_SHIFT );
9051004
906- for (i = 0 ; i < size - 1 ; i ++ ) {
907- d = 0 ;
908- for (j = 0 ; j < PyLong_MARSHAL_RATIO ; j ++ ) {
909- md = r_short (p );
910- if (md < 0 || md > PyLong_MARSHAL_BASE )
911- goto bad_digit ;
912- d += (digit )md << j * PyLong_MARSHAL_SHIFT ;
913- }
914- ob -> long_value .ob_digit [i ] = d ;
1005+ /* other assumptions on PyLongObject internals */
1006+ assert (layout -> bits_per_digit <= 32 );
1007+ assert (layout -> digits_order == -1 );
1008+ assert (layout -> digit_endianness == (PY_LITTLE_ENDIAN ? -1 : 1 ));
1009+ assert (layout -> digit_size == 2 || layout -> digit_size == 4 );
1010+
1011+ Py_ssize_t size = 1 + (Py_ABS (n ) - 1 ) / marshal_ratio ;
1012+
1013+ assert (size >= 1 );
1014+
1015+ int shorts_in_top_digit = 1 + (Py_ABS (n ) - 1 ) % marshal_ratio ;
1016+ void * digits ;
1017+ PyLongWriter * writer = PyLongWriter_Create (n < 0 , size , & digits );
1018+
1019+ if (writer == NULL ) {
1020+ return NULL ;
9151021 }
9161022
917- d = 0 ;
918- for (j = 0 ; j < shorts_in_top_digit ; j ++ ) {
919- md = r_short (p );
920- if (md < 0 || md > PyLong_MARSHAL_BASE )
921- goto bad_digit ;
922- /* topmost marshal digit should be nonzero */
923- if (md == 0 && j == shorts_in_top_digit - 1 ) {
924- Py_DECREF (ob );
925- PyErr_SetString (PyExc_ValueError ,
926- "bad marshal data (unnormalized long data)" );
927- return NULL ;
928- }
929- d += (digit )md << j * PyLong_MARSHAL_SHIFT ;
1023+ int ret ;
1024+
1025+ if (layout -> digit_size == 4 ) {
1026+ ret = _w_digits32 (digits , size , marshal_ratio , shorts_in_top_digit , p );
9301027 }
931- assert (!PyErr_Occurred ());
932- /* top digit should be nonzero, else the resulting PyLong won't be
933- normalized */
934- ob -> long_value .ob_digit [size - 1 ] = d ;
935- return (PyObject * )ob ;
936- bad_digit :
937- Py_DECREF (ob );
938- if (!PyErr_Occurred ()) {
939- PyErr_SetString (PyExc_ValueError ,
940- "bad marshal data (digit out of range in long)" );
1028+ else {
1029+ ret = _w_digits16 (digits , size , marshal_ratio , shorts_in_top_digit , p );
1030+ }
1031+ if (ret < 0 ) {
1032+ PyLongWriter_Discard (writer );
1033+ return NULL ;
9411034 }
942- return NULL ;
1035+ return PyLongWriter_Finish ( writer ) ;
9431036}
9441037
9451038static double
0 commit comments