@@ -116,6 +116,7 @@ def __init__(self,
         if not callable(probability_fn):
             raise TypeError("probability_fn must be callable, saw type: %s" %
                             type(probability_fn).__name__)
+        self.default_probability_fn = probability_fn
         self.probability_fn = probability_fn

         self.keys = None
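Note on the hunk above: the unwrapped probability_fn is stored separately so that the masking added later in setup_memory() always starts from the original callable rather than from an already-wrapped one. A minimal sketch of the pattern, using an illustrative class name rather than the actual mechanism class:

class _MechanismSketch:  # illustrative only, not the real attention class
    def __init__(self, probability_fn):
        if not callable(probability_fn):
            raise TypeError("probability_fn must be callable, saw type: %s" %
                            type(probability_fn).__name__)
        # Kept pristine; setup_memory() re-wraps from this on every call.
        self.default_probability_fn = probability_fn
        # May later be replaced by a mask-aware wrapper around the default.
        self.probability_fn = probability_fn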
@@ -226,7 +227,7 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
             else:
                 memory, memory_sequence_length = inputs, None
                 memory_mask = mask
-            self._setup_memory(memory, memory_sequence_length, memory_mask)
+            self.setup_memory(memory, memory_sequence_length, memory_mask)
             # We force the self.built to false here since only memory is,
             # initialized but the real query/state has not been call() yet. The
             # layer should be build and call again.
@@ -248,10 +249,10 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
             query, state = inputs[0], inputs[1]
             return self._calculate_attention(query, state)

-    def _setup_memory(self,
-                      memory,
-                      memory_sequence_length=None,
-                      memory_mask=None):
+    def setup_memory(self,
+                     memory,
+                     memory_sequence_length=None,
+                     memory_mask=None):
         """Pre-process the memory before actually query the memory.

         This should only be called once at the first invocation of call().
@@ -266,9 +267,6 @@ def _setup_memory(self,
             max_time]`. For any value equal to False, the corresponding value
             in memory should be ignored.
         """
-        if self._memory_initialized:
-            raise ValueError(
-                "The memory for the attention has already been setup.")
         if memory_sequence_length is not None and memory_mask is not None:
             raise ValueError(
                 "memory_sequence_length and memory_mask cannot be "
@@ -293,7 +291,7 @@ def _setup_memory(self,
         self._alignments_size = (tf.compat.dimension_value(
             self.keys.shape[1]) or tf.shape(self.keys)[1])
         if memory_mask is not None or memory_sequence_length is not None:
-            unwrapped_probability_fn = self.probability_fn
+            unwrapped_probability_fn = self.default_probability_fn

             def _mask_probability_fn(score, prev):
                 return unwrapped_probability_fn(
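The switch from self.probability_fn to self.default_probability_fn matters when setup_memory() is invoked more than once: wrapping the already-masked probability_fn again would stack masks from stale memories. A rough sketch of the wrapping pattern, assuming a boolean memory_mask with the same [batch, max_time] shape as the scores (this is not the module's exact masking helper):

import tensorflow as tf

def _make_masked_probability_fn(unwrapped_probability_fn, memory_mask,
                                score_mask_value=-float("inf")):
    # Always wrap the original callable, never a previously wrapped one.
    def _mask_probability_fn(score, prev):
        # Push padded memory positions to score_mask_value before the original
        # probability_fn (e.g. a softmax over scores) is applied.
        masked_score = tf.where(memory_mask, score,
                                tf.fill(tf.shape(score), score_mask_value))
        return unwrapped_probability_fn(masked_score, prev)

    return _mask_probability_fn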
@@ -505,7 +503,7 @@ class LuongAttention(_BaseAttentionMechanism):

     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  scale=False,
                  probability_fn="softmax",
@@ -671,7 +669,7 @@ class BahdanauAttention(_BaseAttentionMechanism):

     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  normalize=False,
                  probability_fn="softmax",
@@ -1013,7 +1011,7 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):

     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  normalize=False,
                  sigmoid_noise=0.,
@@ -1186,7 +1184,7 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):

     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  scale=False,
                  sigmoid_noise=0.,
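With memory now defaulting to None in all four mechanisms and setup_memory() public, construction and memory initialization can be decoupled. A minimal usage sketch, assuming these classes are exposed under tfa.seq2seq (the tensor shapes are illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

attention = tfa.seq2seq.LuongAttention(units=128)  # no memory at construction
encoder_outputs = tf.random.normal([8, 20, 128])   # [batch, max_time, depth]
memory_lengths = tf.fill([8], 20)
# Supply the memory later, e.g. once the encoder has produced its outputs.
attention.setup_memory(encoder_outputs,
                       memory_sequence_length=memory_lengths)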