@@ -174,37 +174,38 @@ def __init__(self, value):
174174
175175# Pickling machinery
176176
177- class Pickler :
177+ class _Pickler :
178178
179179 def __init__ (self , file , protocol = None ):
180180 """This takes a binary file for writing a pickle data stream.
181181
182182 All protocols now read and write bytes.
183183
184184 The optional protocol argument tells the pickler to use the
185- given protocol; supported protocols are 0, 1, 2. The default
186- protocol is 2; it's been supported for many years now.
187-
188- Protocol 1 is more efficient than protocol 0; protocol 2 is
189- more efficient than protocol 1.
185+ given protocol; supported protocols are 0, 1, 2, 3. The default
186+ protocol is 3; a backward-incompatible protocol designed for
187+ Python 3.0.
190188
191189 Specifying a negative protocol version selects the highest
192190 protocol version supported. The higher the protocol used, the
193191 more recent the version of Python needed to read the pickle
194192 produced.
195193
196- The file parameter must have a write() method that accepts a single
197- string argument. It can thus be an open file object, a StringIO
198- object, or any other custom object that meets this interface.
199-
194+ The file argument must have a write() method that accepts a single
195+ bytes argument. It can thus be a file object opened for binary
196+ writing, a io.BytesIO instance, or any other custom object that
197+ meets this interface.
200198 """
201199 if protocol is None :
202200 protocol = DEFAULT_PROTOCOL
203201 if protocol < 0 :
204202 protocol = HIGHEST_PROTOCOL
205203 elif not 0 <= protocol <= HIGHEST_PROTOCOL :
206204 raise ValueError ("pickle protocol must be <= %d" % HIGHEST_PROTOCOL )
207- self .write = file .write
205+ try :
206+ self .write = file .write
207+ except AttributeError :
208+ raise TypeError ("file must have a 'write' attribute" )
208209 self .memo = {}
209210 self .proto = int (protocol )
210211 self .bin = protocol >= 1
@@ -270,10 +271,10 @@ def get(self, i, pack=struct.pack):
270271
271272 return GET + repr (i ).encode ("ascii" ) + b'\n '
272273
273- def save (self , obj ):
274+ def save (self , obj , save_persistent_id = True ):
274275 # Check for persistent id (defined by a subclass)
275276 pid = self .persistent_id (obj )
276- if pid :
277+ if pid is not None and save_persistent_id :
277278 self .save_pers (pid )
278279 return
279280
@@ -341,7 +342,7 @@ def persistent_id(self, obj):
341342 def save_pers (self , pid ):
342343 # Save a persistent id reference
343344 if self .bin :
344- self .save (pid )
345+ self .save (pid , save_persistent_id = False )
345346 self .write (BINPERSID )
346347 else :
347348 self .write (PERSID + str (pid ).encode ("ascii" ) + b'\n ' )
@@ -350,13 +351,13 @@ def save_reduce(self, func, args, state=None,
350351 listitems = None , dictitems = None , obj = None ):
351352 # This API is called by some subclasses
352353
353- # Assert that args is a tuple or None
354+ # Assert that args is a tuple
354355 if not isinstance (args , tuple ):
355- raise PicklingError ("args from reduce () should be a tuple" )
356+ raise PicklingError ("args from save_reduce () should be a tuple" )
356357
357358 # Assert that func is callable
358359 if not hasattr (func , '__call__' ):
359- raise PicklingError ("func from reduce should be callable" )
360+ raise PicklingError ("func from save_reduce() should be callable" )
360361
361362 save = self .save
362363 write = self .write
@@ -438,31 +439,6 @@ def save_bool(self, obj):
438439 self .write (obj and TRUE or FALSE )
439440 dispatch [bool ] = save_bool
440441
441- def save_int (self , obj , pack = struct .pack ):
442- if self .bin :
443- # If the int is small enough to fit in a signed 4-byte 2's-comp
444- # format, we can store it more efficiently than the general
445- # case.
446- # First one- and two-byte unsigned ints:
447- if obj >= 0 :
448- if obj <= 0xff :
449- self .write (BININT1 + bytes ([obj ]))
450- return
451- if obj <= 0xffff :
452- self .write (BININT2 + bytes ([obj & 0xff , obj >> 8 ]))
453- return
454- # Next check for 4-byte signed ints:
455- high_bits = obj >> 31 # note that Python shift sign-extends
456- if high_bits == 0 or high_bits == - 1 :
457- # All high bits are copies of bit 2**31, so the value
458- # fits in a 4-byte signed int.
459- self .write (BININT + pack ("<i" , obj ))
460- return
461- # Text pickle, or int too big to fit in signed 4-byte format.
462- self .write (INT + repr (obj ).encode ("ascii" ) + b'\n ' )
463- # XXX save_int is merged into save_long
464- # dispatch[int] = save_int
465-
466442 def save_long (self , obj , pack = struct .pack ):
467443 if self .bin :
468444 # If the int is small enough to fit in a signed 4-byte 2's-comp
@@ -503,7 +479,7 @@ def save_float(self, obj, pack=struct.pack):
503479
504480 def save_bytes (self , obj , pack = struct .pack ):
505481 if self .proto < 3 :
506- self .save_reduce (bytes , (list (obj ),))
482+ self .save_reduce (bytes , (list (obj ),), obj = obj )
507483 return
508484 n = len (obj )
509485 if n < 256 :
@@ -579,12 +555,6 @@ def save_tuple(self, obj):
579555
580556 dispatch [tuple ] = save_tuple
581557
582- # save_empty_tuple() isn't used by anything in Python 2.3. However, I
583- # found a Pickler subclass in Zope3 that calls it, so it's not harmless
584- # to remove it.
585- def save_empty_tuple (self , obj ):
586- self .write (EMPTY_TUPLE )
587-
588558 def save_list (self , obj ):
589559 write = self .write
590560
@@ -696,7 +666,7 @@ def save_global(self, obj, name=None, pack=struct.pack):
696666 module = whichmodule (obj , name )
697667
698668 try :
699- __import__ (module )
669+ __import__ (module , level = 0 )
700670 mod = sys .modules [module ]
701671 klass = getattr (mod , name )
702672 except (ImportError , KeyError , AttributeError ):
@@ -720,9 +690,19 @@ def save_global(self, obj, name=None, pack=struct.pack):
720690 else :
721691 write (EXT4 + pack ("<i" , code ))
722692 return
693+ # Non-ASCII identifiers are supported only with protocols >= 3.
694+ if self .proto >= 3 :
695+ write (GLOBAL + bytes (module , "utf-8" ) + b'\n ' +
696+ bytes (name , "utf-8" ) + b'\n ' )
697+ else :
698+ try :
699+ write (GLOBAL + bytes (module , "ascii" ) + b'\n ' +
700+ bytes (name , "ascii" ) + b'\n ' )
701+ except UnicodeEncodeError :
702+ raise PicklingError (
703+ "can't pickle global identifier '%s.%s' using "
704+ "pickle protocol %i" % (module , name , self .proto ))
723705
724- write (GLOBAL + bytes (module , "utf-8" ) + b'\n ' +
725- bytes (name , "utf-8" ) + b'\n ' )
726706 self .memoize (obj )
727707
728708 dispatch [FunctionType ] = save_global
@@ -781,7 +761,7 @@ def whichmodule(func, funcname):
781761
782762# Unpickling machinery
783763
784- class Unpickler :
764+ class _Unpickler :
785765
786766 def __init__ (self , file , * , encoding = "ASCII" , errors = "strict" ):
787767 """This takes a binary file for reading a pickle data stream.
@@ -841,6 +821,9 @@ def marker(self):
841821 while stack [k ] is not mark : k = k - 1
842822 return k
843823
824+ def persistent_load (self , pid ):
825+ raise UnpickingError ("unsupported persistent id encountered" )
826+
844827 dispatch = {}
845828
846829 def load_proto (self ):
@@ -850,7 +833,7 @@ def load_proto(self):
850833 dispatch [PROTO [0 ]] = load_proto
851834
852835 def load_persid (self ):
853- pid = self .readline ()[:- 1 ]
836+ pid = self .readline ()[:- 1 ]. decode ( "ascii" )
854837 self .append (self .persistent_load (pid ))
855838 dispatch [PERSID [0 ]] = load_persid
856839
@@ -879,9 +862,9 @@ def load_int(self):
879862 val = True
880863 else :
881864 try :
882- val = int (data )
865+ val = int (data , 0 )
883866 except ValueError :
884- val = int (data )
867+ val = int (data , 0 )
885868 self .append (val )
886869 dispatch [INT [0 ]] = load_int
887870
@@ -933,7 +916,8 @@ def load_string(self):
933916 break
934917 else :
935918 raise ValueError ("insecure string pickle: %r" % orig )
936- self .append (codecs .escape_decode (rep )[0 ])
919+ self .append (codecs .escape_decode (rep )[0 ]
920+ .decode (self .encoding , self .errors ))
937921 dispatch [STRING [0 ]] = load_string
938922
939923 def load_binstring (self ):
@@ -975,7 +959,7 @@ def load_tuple(self):
975959 dispatch [TUPLE [0 ]] = load_tuple
976960
977961 def load_empty_tuple (self ):
978- self .stack . append (())
962+ self .append (())
979963 dispatch [EMPTY_TUPLE [0 ]] = load_empty_tuple
980964
981965 def load_tuple1 (self ):
@@ -991,11 +975,11 @@ def load_tuple3(self):
991975 dispatch [TUPLE3 [0 ]] = load_tuple3
992976
993977 def load_empty_list (self ):
994- self .stack . append ([])
978+ self .append ([])
995979 dispatch [EMPTY_LIST [0 ]] = load_empty_list
996980
997981 def load_empty_dictionary (self ):
998- self .stack . append ({})
982+ self .append ({})
999983 dispatch [EMPTY_DICT [0 ]] = load_empty_dictionary
1000984
1001985 def load_list (self ):
@@ -1022,13 +1006,13 @@ def load_dict(self):
10221006 def _instantiate (self , klass , k ):
10231007 args = tuple (self .stack [k + 1 :])
10241008 del self .stack [k :]
1025- instantiated = 0
1009+ instantiated = False
10261010 if (not args and
10271011 isinstance (klass , type ) and
10281012 not hasattr (klass , "__getinitargs__" )):
10291013 value = _EmptyClass ()
10301014 value .__class__ = klass
1031- instantiated = 1
1015+ instantiated = True
10321016 if not instantiated :
10331017 try :
10341018 value = klass (* args )
@@ -1038,8 +1022,8 @@ def _instantiate(self, klass, k):
10381022 self .append (value )
10391023
10401024 def load_inst (self ):
1041- module = self .readline ()[:- 1 ]
1042- name = self .readline ()[:- 1 ]
1025+ module = self .readline ()[:- 1 ]. decode ( "ascii" )
1026+ name = self .readline ()[:- 1 ]. decode ( "ascii" )
10431027 klass = self .find_class (module , name )
10441028 self ._instantiate (klass , self .marker ())
10451029 dispatch [INST [0 ]] = load_inst
@@ -1059,8 +1043,8 @@ def load_newobj(self):
10591043 dispatch [NEWOBJ [0 ]] = load_newobj
10601044
10611045 def load_global (self ):
1062- module = self .readline ()[:- 1 ]
1063- name = self .readline ()[:- 1 ]
1046+ module = self .readline ()[:- 1 ]. decode ( "utf-8" )
1047+ name = self .readline ()[:- 1 ]. decode ( "utf-8" )
10641048 klass = self .find_class (module , name )
10651049 self .append (klass )
10661050 dispatch [GLOBAL [0 ]] = load_global
@@ -1095,11 +1079,7 @@ def get_extension(self, code):
10951079
10961080 def find_class (self , module , name ):
10971081 # Subclasses may override this
1098- if isinstance (module , bytes_types ):
1099- module = module .decode ("utf-8" )
1100- if isinstance (name , bytes_types ):
1101- name = name .decode ("utf-8" )
1102- __import__ (module )
1082+ __import__ (module , level = 0 )
11031083 mod = sys .modules [module ]
11041084 klass = getattr (mod , name )
11051085 return klass
@@ -1131,31 +1111,33 @@ def load_dup(self):
11311111 dispatch [DUP [0 ]] = load_dup
11321112
11331113 def load_get (self ):
1134- self .append (self .memo [self .readline ()[:- 1 ].decode ("ascii" )])
1114+ i = int (self .readline ()[:- 1 ])
1115+ self .append (self .memo [i ])
11351116 dispatch [GET [0 ]] = load_get
11361117
11371118 def load_binget (self ):
1138- i = ord ( self .read (1 ))
1139- self .append (self .memo [repr ( i ) ])
1119+ i = self .read (1 )[ 0 ]
1120+ self .append (self .memo [i ])
11401121 dispatch [BINGET [0 ]] = load_binget
11411122
11421123 def load_long_binget (self ):
11431124 i = mloads (b'i' + self .read (4 ))
1144- self .append (self .memo [repr ( i ) ])
1125+ self .append (self .memo [i ])
11451126 dispatch [LONG_BINGET [0 ]] = load_long_binget
11461127
11471128 def load_put (self ):
1148- self .memo [self .readline ()[:- 1 ].decode ("ascii" )] = self .stack [- 1 ]
1129+ i = int (self .readline ()[:- 1 ])
1130+ self .memo [i ] = self .stack [- 1 ]
11491131 dispatch [PUT [0 ]] = load_put
11501132
11511133 def load_binput (self ):
1152- i = ord ( self .read (1 ))
1153- self .memo [repr ( i ) ] = self .stack [- 1 ]
1134+ i = self .read (1 )[ 0 ]
1135+ self .memo [i ] = self .stack [- 1 ]
11541136 dispatch [BINPUT [0 ]] = load_binput
11551137
11561138 def load_long_binput (self ):
11571139 i = mloads (b'i' + self .read (4 ))
1158- self .memo [repr ( i ) ] = self .stack [- 1 ]
1140+ self .memo [i ] = self .stack [- 1 ]
11591141 dispatch [LONG_BINPUT [0 ]] = load_long_binput
11601142
11611143 def load_append (self ):
@@ -1321,6 +1303,12 @@ def decode_long(data):
13211303 n -= 1 << (nbytes * 8 )
13221304 return n
13231305
1306+ # Use the faster _pickle if possible
1307+ try :
1308+ from _pickle import *
1309+ except ImportError :
1310+ Pickler , Unpickler = _Pickler , _Unpickler
1311+
13241312# Shorthands
13251313
13261314def dump (obj , file , protocol = None ):
@@ -1333,14 +1321,14 @@ def dumps(obj, protocol=None):
13331321 assert isinstance (res , bytes_types )
13341322 return res
13351323
1336- def load (file ):
1337- return Unpickler (file ).load ()
1324+ def load (file , * , encoding = "ASCII" , errors = "strict" ):
1325+ return Unpickler (file , encoding = encoding , errors = errors ).load ()
13381326
1339- def loads (s ):
1327+ def loads (s , * , encoding = "ASCII" , errors = "strict" ):
13401328 if isinstance (s , str ):
13411329 raise TypeError ("Can't load pickle from unicode string" )
13421330 file = io .BytesIO (s )
1343- return Unpickler (file ).load ()
1331+ return Unpickler (file , encoding = encoding , errors = errors ).load ()
13441332
13451333# Doctest
13461334
0 commit comments