@@ -650,6 +650,7 @@ def __init__(
650650 # The sequence of model-generated RNGs
651651 self .rng_seq = []
652652 self ._initial_values = {}
653+ self ._initial_point_cache = {}
653654
654655 if self .parent is not None :
655656 self .named_vars = treedict (parent = self .parent .named_vars )
@@ -926,15 +927,28 @@ def cont_vars(self):
926927 def test_point (self ) -> Dict [str , np .ndarray ]:
927928 """Deprecated alias for `Model.initial_point`."""
928929 warnings .warn (
929- "`Model.test_point` has been deprecated. Use `Model.initial_point` instead ." ,
930+ "`Model.test_point` has been deprecated. Use `Model.initial_point` or `Model.recompute_initial_point()` ." ,
930931 DeprecationWarning ,
931932 )
932933 return self .initial_point
933934
934935 @property
935936 def initial_point (self ) -> Dict [str , np .ndarray ]:
936- """Maps names of variables to initial values."""
937- return Point (list (self .initial_values .items ()), model = self )
937+ """Maps free variable names to transformed, numeric initial values."""
938+ if set (self ._initial_point_cache ) != {get_var_name (k ) for k in self .initial_values }:
939+ return self .recompute_initial_point ()
940+ return self ._initial_point_cache
941+
942+ def recompute_initial_point (self ) -> Dict [str , np .ndarray ]:
943+ """Recomputes numeric initial values for all free model variables.
944+
945+ Returns
946+ -------
947+ initial_point : dict
948+ Maps free variable names to transformed, numeric initial values.
949+ """
950+ self ._initial_point_cache = Point (list (self .initial_values .items ()), model = self )
951+ return self ._initial_point_cache
938952
939953 @property
940954 def initial_values (self ) -> Dict [TensorVariable , np .ndarray ]:
0 commit comments