40
40
# Library worker import reloaded in init and reload request
41
41
_library_worker = None
42
42
43
+ # Thread-local invocation ID registry for efficient lookup
44
+ _thread_invocation_registry : typing .Dict [int , str ] = {}
45
+ _registry_lock = threading .Lock ()
46
+
47
+ # Global current invocation tracker (as a fallback)
48
+ _current_invocation_id : Optional [str ] = None
49
+ _current_invocation_lock = threading .Lock ()
50
+
43
51
44
52
class ContextEnabledTask (asyncio .Task ):
45
53
AZURE_INVOCATION_ID = '__azure_function_invocation_id__'
@@ -61,16 +69,63 @@ def set_azure_invocation_id(self, invocation_id: str) -> None:
61
69
_invocation_id_local = threading .local ()
62
70
63
71
72
+ def set_thread_invocation_id (thread_id : int , invocation_id : str ) -> None :
73
+ """Set the invocation ID for a specific thread"""
74
+ with _registry_lock :
75
+ _thread_invocation_registry [thread_id ] = invocation_id
76
+
77
+
78
+ def clear_thread_invocation_id (thread_id : int ) -> None :
79
+ """Clear the invocation ID for a specific thread"""
80
+ with _registry_lock :
81
+ _thread_invocation_registry .pop (thread_id , None )
82
+
83
+
84
+ def get_thread_invocation_id (thread_id : int ) -> Optional [str ]:
85
+ """Get the invocation ID for a specific thread"""
86
+ with _registry_lock :
87
+ return _thread_invocation_registry .get (thread_id )
88
+
89
+
90
+ def set_current_invocation_id (invocation_id : str ) -> None :
91
+ """Set the global current invocation ID"""
92
+ global _current_invocation_id
93
+ with _current_invocation_lock :
94
+ _current_invocation_id = invocation_id
95
+
96
+
97
+ def get_global_current_invocation_id () -> Optional [str ]:
98
+ """Get the global current invocation ID"""
99
+ with _current_invocation_lock :
100
+ return _current_invocation_id
101
+
102
+
64
103
def get_current_invocation_id () -> Optional [Any ]:
65
- loop = asyncio ._get_running_loop ()
66
- if loop is not None :
67
- current_task = asyncio .current_task (loop )
68
- if current_task is not None :
69
- task_invocation_id = getattr (current_task ,
70
- ContextEnabledTask .AZURE_INVOCATION_ID ,
71
- None )
72
- if task_invocation_id is not None :
73
- return task_invocation_id
104
+ # Check global current invocation first (most up-to-date)
105
+ global_invocation_id = get_global_current_invocation_id ()
106
+ if global_invocation_id is not None :
107
+ return global_invocation_id
108
+
109
+ # Check asyncio task context
110
+ try :
111
+ loop = asyncio ._get_running_loop ()
112
+ if loop is not None :
113
+ current_task = asyncio .current_task (loop )
114
+ if current_task is not None :
115
+ task_invocation_id = getattr (current_task ,
116
+ ContextEnabledTask .AZURE_INVOCATION_ID ,
117
+ None )
118
+ if task_invocation_id is not None :
119
+ return task_invocation_id
120
+ except RuntimeError :
121
+ # No event loop running
122
+ pass
123
+
124
+ # Check the thread-local invocation ID registry
125
+ current_thread_id = threading .get_ident ()
126
+ thread_invocation_id = get_thread_invocation_id (current_thread_id )
127
+ if thread_invocation_id is not None :
128
+ return thread_invocation_id
74
129
75
130
return getattr (_invocation_id_local , 'invocation_id' , None )
76
131
@@ -516,14 +571,32 @@ async def _handle__invocation_request(self, request):
516
571
'invocation_id: %s, worker_id: %s' ,
517
572
self .request_id , function_id , invocation_id , self .worker_id )
518
573
519
- invocation_request = WorkerRequest (name = "FunctionInvocationRequest" ,
520
- request = request ,
521
- properties = {
522
- "threadpool" : self ._sync_call_tp })
523
- invocation_response = await (
524
- _library_worker .invocation_request ( # type: ignore[union-attr]
525
- invocation_request ))
574
+ # Set the global current invocation ID first (for all threads to access)
575
+ set_current_invocation_id (invocation_id )
526
576
527
- return protos .StreamingMessage (
528
- request_id = self .request_id ,
529
- invocation_response = invocation_response )
577
+ # Set the current `invocation_id` to the current task so
578
+ # that our logging handler can find it.
579
+ current_task = asyncio .current_task ()
580
+ if current_task is not None and isinstance (current_task , ContextEnabledTask ):
581
+ current_task .set_azure_invocation_id (invocation_id )
582
+
583
+ # Register the invocation ID for the current thread
584
+ current_thread_id = threading .get_ident ()
585
+ set_thread_invocation_id (current_thread_id , invocation_id )
586
+
587
+ try :
588
+ invocation_request = WorkerRequest (name = "FunctionInvocationRequest" ,
589
+ request = request ,
590
+ properties = {
591
+ "threadpool" : self ._sync_call_tp })
592
+ invocation_response = await (
593
+ _library_worker .invocation_request ( # type: ignore[union-attr]
594
+ invocation_request ))
595
+
596
+ return protos .StreamingMessage (
597
+ request_id = self .request_id ,
598
+ invocation_response = invocation_response )
599
+ except Exception :
600
+ # Clear thread registry on exception to prevent stale IDs
601
+ clear_thread_invocation_id (current_thread_id )
602
+ raise
0 commit comments