1212from lightning_app .storage .drive import _maybe_create_drive , Drive
1313from lightning_app .storage .payload import Payload
1414from lightning_app .utilities .app_helpers import _is_json_serializable , _LightningAppRef
15- from lightning_app .utilities .component import _sanitize_state
16- from lightning_app .utilities .enum import make_status , WorkFailureReasons , WorkStageStatus , WorkStatus , WorkStopReasons
15+ from lightning_app .utilities .component import _is_flow_context , _sanitize_state
16+ from lightning_app .utilities .enum import (
17+ CacheCallsKeys ,
18+ make_status ,
19+ WorkFailureReasons ,
20+ WorkStageStatus ,
21+ WorkStatus ,
22+ WorkStopReasons ,
23+ )
1724from lightning_app .utilities .exceptions import LightningWorkException
1825from lightning_app .utilities .introspection import _is_init_context
1926from lightning_app .utilities .network import find_free_network_port
@@ -107,7 +114,21 @@ def __init__(
107114 # setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator
108115 self ._setattr_replacement : Optional [Callable [[str , Any ], None ]] = None
109116 self ._name = ""
110- self ._calls = {"latest_call_hash" : None }
117+ # The ``self._calls`` is used to track whether the run
118+ # method with a given set of input arguments has already been called.
119+ # Example of its usage:
120+ # {
121+ # 'latest_call_hash': '167fe2e',
122+ # '167fe2e': {
123+ # 'statuses': [
124+ # {'stage': 'pending', 'timestamp': 1659433519.851271},
125+ # {'stage': 'running', 'timestamp': 1659433519.956482},
126+ # {'stage': 'stopped', 'timestamp': 1659433520.055768}
127+ # ]
128+ # },
129+ # ...
130+ # }
131+ self ._calls = {CacheCallsKeys .LATEST_CALL_HASH : None }
111132 self ._changes = {}
112133 self ._raise_exception = raise_exception
113134 self ._paths = {}
@@ -215,22 +236,22 @@ def status(self) -> WorkStatus:
215236
216237 All statuses are stored in the state.
217238 """
218- call_hash = self ._calls ["latest_call_hash" ]
219- if call_hash :
239+ call_hash = self ._calls [CacheCallsKeys . LATEST_CALL_HASH ]
240+ if call_hash in self . _calls :
220241 statuses = self ._calls [call_hash ]["statuses" ]
221242 # deltas aren't necessarily coming in the expected order.
222243 statuses = sorted (statuses , key = lambda x : x ["timestamp" ])
223244 latest_status = statuses [- 1 ]
224- if latest_status [ "reason" ] == WorkFailureReasons .TIMEOUT :
245+ if latest_status . get ( "reason" ) == WorkFailureReasons .TIMEOUT :
225246 return self ._aggregate_status_timeout (statuses )
226247 return WorkStatus (** latest_status )
227248 return WorkStatus (stage = WorkStageStatus .NOT_STARTED , timestamp = time .time ())
228249
229250 @property
230251 def statuses (self ) -> List [WorkStatus ]:
231252 """Return all the status of the work."""
232- call_hash = self ._calls ["latest_call_hash" ]
233- if call_hash :
253+ call_hash = self ._calls [CacheCallsKeys . LATEST_CALL_HASH ]
254+ if call_hash in self . _calls :
234255 statuses = self ._calls [call_hash ]["statuses" ]
235256 # deltas aren't necessarily coming in the expected order.
236257 statuses = sorted (statuses , key = lambda x : x ["timestamp" ])
@@ -398,10 +419,13 @@ def __getattr__(self, item):
398419 return path
399420 return self .__getattribute__ (item )
400421
def _call_hash(self, fn, args, kwargs) -> str:
    """Return a short hash identifying a call by its arguments.

    Args:
        fn: The called function. Kept for interface compatibility; it does
            not contribute to the hash.
        args: Positional arguments of the call. When the first one is this
            very instance (a bound ``self``), it is left out of the hash.
        kwargs: Keyword arguments of the call.

    Returns:
        The first seven characters of the ``DeepHash`` digest computed over
        the positional and keyword arguments, e.g. ``"167fe2e"``.
    """
    # Identity, not equality: ``==`` would dispatch to the ``__eq__`` of the
    # first argument, which may return a non-boolean (array-like inputs) or
    # wrongly match an unrelated object, dropping a real argument.
    hash_args = args[1:] if len(args) > 0 and args[0] is self else args
    call_obj = {"args": hash_args, "kwargs": kwargs}
    # The digest is truncated to seven characters (the length GitHub shows
    # for abbreviated SHAs) to keep the hidden state small.
    return str(DeepHash(call_obj)[call_obj])[:7]
405429
406430 def _wrap_run_for_caching (self , fn ):
407431 @wraps (fn )
@@ -415,11 +439,11 @@ def new_fn(*args, **kwargs):
415439 entry = self ._calls [call_hash ]
416440 return entry ["ret" ]
417441
418- self ._calls [call_hash ] = {"name" : fn . __name__ , "call_hash" : call_hash }
442+ self ._calls [call_hash ] = {}
419443
420444 result = fn (* args , ** kwargs )
421445
422- self ._calls [call_hash ] = {"name" : fn . __name__ , "call_hash" : call_hash , " ret" : result }
446+ self ._calls [call_hash ] = {"ret" : result }
423447
424448 return result
425449
@@ -457,8 +481,40 @@ def set_state(self, provided_state):
457481 if isinstance (v , Dict ):
458482 v = _maybe_create_drive (self .name , v )
459483 setattr (self , k , v )
484+
460485 self ._changes = provided_state ["changes" ]
461- self ._calls .update (provided_state ["calls" ])
486+
487+ # Note, this is handled by the flow only.
488+ if _is_flow_context ():
489+ self ._cleanup_calls (provided_state ["calls" ])
490+
491+ self ._calls = provided_state ["calls" ]
492+
@staticmethod
def _cleanup_calls(calls: Dict[str, Any]):
    """Compact, in place, the status history stored for every call hash.

    For each entry other than ``CacheCallsKeys.LATEST_CALL_HASH`` that has a
    non-empty ``"statuses"`` list:

    * if the most recent status (by timestamp) is ``SUCCEEDED``, only that
      status is kept and its timestamp is truncated to an ``int``;
    * otherwise the statuses are kept in timestamp order with exact
      duplicates removed.

    Args:
        calls: Mapping of call hashes to their call entries. The entries'
            ``"statuses"`` lists are replaced in place; nothing is returned.
    """
    for call_hash, entry in calls.items():
        # Exact comparison. The previous ``k not in (KEY)`` had no trailing
        # comma, so it was a substring test against the key rather than a
        # tuple membership test.
        if call_hash == CacheCallsKeys.LATEST_CALL_HASH:
            continue

        if "statuses" not in entry:
            continue

        # Deltas may arrive out of order, so sort the statuses by timestamp.
        statuses = sorted(entry["statuses"], key=lambda x: x["timestamp"])

        # Nothing recorded yet: there is no latest status to inspect.
        if not statuses:
            continue

        latest = statuses[-1]
        if latest["stage"] == WorkStageStatus.SUCCEEDED:
            # The call succeeded: drop everything that came before.
            latest["timestamp"] = int(latest["timestamp"])
            entry["statuses"] = [latest]
            continue

        # TODO: Some statuses are being duplicated,
        # this seems related to the StateObserver.
        # Statuses are dicts (unhashable), hence the list-based de-duplication.
        final_statuses = []
        for status in statuses:
            if status not in final_statuses:
                final_statuses.append(status)
        entry["statuses"] = final_statuses
462518
463519 @abc .abstractmethod
464520 def run (self , * args , ** kwargs ):
@@ -479,7 +535,7 @@ def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus:
479535 if succeeded_statuses :
480536 succeed_status_id = succeeded_statuses [- 1 ] + 1
481537 statuses = statuses [succeed_status_id :]
482- timeout_statuses = [status for status in statuses if status [ "reason" ] == WorkFailureReasons .TIMEOUT ]
538+ timeout_statuses = [status for status in statuses if status . get ( "reason" ) == WorkFailureReasons .TIMEOUT ]
483539 assert statuses [0 ]["stage" ] == WorkStageStatus .PENDING
484540 status = {** timeout_statuses [- 1 ], "timestamp" : statuses [0 ]["timestamp" ]}
485541 return WorkStatus (** status , count = len (timeout_statuses ))
@@ -501,9 +557,8 @@ def stop(self):
501557 )
502558 if self .status .stage == WorkStageStatus .STOPPED :
503559 return
504- latest_hash = self ._calls ["latest_call_hash" ]
505- self ._calls [latest_hash ]["statuses" ].append (
506- make_status (WorkStageStatus .STOPPED , reason = WorkStopReasons .PENDING )
507- )
560+ latest_hash = self ._calls [CacheCallsKeys .LATEST_CALL_HASH ]
561+ stop_status = make_status (WorkStageStatus .STOPPED , reason = WorkStopReasons .PENDING )
562+ self ._calls [latest_hash ]["statuses" ].append (stop_status )
508563 app = _LightningAppRef ().get_current ()
509564 self ._backend .stop_work (app , self )
0 commit comments