@@ -59,14 +59,11 @@ def __stream__(self) -> Iterator[_T]:
5959 if sse .data .startswith ("[DONE]" ):
6060 break
6161
62- if sse .event is None or (
63- sse .event .startswith ("response." ) or
64- sse .event .startswith ("transcript." ) or
65- sse .event .startswith ("image_edit." ) or
66- sse .event .startswith ("image_generation." )
67- ):
62+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
63+ if sse .event and sse .event .startswith ("thread." ):
6864 data = sse .json ()
69- if is_mapping (data ) and data .get ("error" ):
65+
66+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
7067 message = None
7168 error = data .get ("error" )
7269 if is_mapping (error ):
@@ -80,12 +77,10 @@ def __stream__(self) -> Iterator[_T]:
8077 body = data ["error" ],
8178 )
8279
83- yield process_data (data = data , cast_to = cast_to , response = response )
84-
80+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
8581 else :
8682 data = sse .json ()
87-
88- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
83+ if is_mapping (data ) and data .get ("error" ):
8984 message = None
9085 error = data .get ("error" )
9186 if is_mapping (error ):
@@ -99,7 +94,7 @@ def __stream__(self) -> Iterator[_T]:
9994 body = data ["error" ],
10095 )
10196
102- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
97+ yield process_data (data = data , cast_to = cast_to , response = response )
10398
10499 # Ensure the entire stream is consumed
105100 for _sse in iterator :
@@ -166,9 +161,11 @@ async def __stream__(self) -> AsyncIterator[_T]:
166161 if sse .data .startswith ("[DONE]" ):
167162 break
168163
169- if sse .event is None or sse .event .startswith ("response." ) or sse .event .startswith ("transcript." ):
164+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
165+ if sse .event and sse .event .startswith ("thread." ):
170166 data = sse .json ()
171- if is_mapping (data ) and data .get ("error" ):
167+
168+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
172169 message = None
173170 error = data .get ("error" )
174171 if is_mapping (error ):
@@ -182,12 +179,10 @@ async def __stream__(self) -> AsyncIterator[_T]:
182179 body = data ["error" ],
183180 )
184181
185- yield process_data (data = data , cast_to = cast_to , response = response )
186-
182+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
187183 else :
188184 data = sse .json ()
189-
190- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
185+ if is_mapping (data ) and data .get ("error" ):
191186 message = None
192187 error = data .get ("error" )
193188 if is_mapping (error ):
@@ -201,7 +196,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
201196 body = data ["error" ],
202197 )
203198
204- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
199+ yield process_data (data = data , cast_to = cast_to , response = response )
205200
206201 # Ensure the entire stream is consumed
207202 async for _sse in iterator :
0 commit comments