@@ -59,9 +59,11 @@ def __stream__(self) -> Iterator[_T]:
5959 if sse .data .startswith ("[DONE]" ):
6060 break
6161
62- if sse .event is not None and not sse .event .startswith ("thread." ):
62+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
63+ if sse .event and sse .event .startswith ("thread." ):
6364 data = sse .json ()
64- if is_mapping (data ) and data .get ("error" ):
65+
66+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
6567 message = None
6668 error = data .get ("error" )
6769 if is_mapping (error ):
@@ -75,12 +77,10 @@ def __stream__(self) -> Iterator[_T]:
7577 body = data ["error" ],
7678 )
7779
78- yield process_data (data = data , cast_to = cast_to , response = response )
79-
80+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
8081 else :
8182 data = sse .json ()
82-
83- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
83+ if is_mapping (data ) and data .get ("error" ):
8484 message = None
8585 error = data .get ("error" )
8686 if is_mapping (error ):
@@ -93,8 +93,8 @@ def __stream__(self) -> Iterator[_T]:
9393 request = self .response .request ,
9494 body = data ["error" ],
9595 )
96- # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
97- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
96+
97+ yield process_data (data = data , cast_to = cast_to , response = response )
9898
9999 # Ensure the entire stream is consumed
100100 for _sse in iterator :
@@ -161,9 +161,11 @@ async def __stream__(self) -> AsyncIterator[_T]:
161161 if sse .data .startswith ("[DONE]" ):
162162 break
163163
164- if sse .event is None or sse .event .startswith ("response." ) or sse .event .startswith ("transcript." ):
164+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
165+ if sse .event and sse .event .startswith ("thread." ):
165166 data = sse .json ()
166- if is_mapping (data ) and data .get ("error" ):
167+
168+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
167169 message = None
168170 error = data .get ("error" )
169171 if is_mapping (error ):
@@ -177,12 +179,10 @@ async def __stream__(self) -> AsyncIterator[_T]:
177179 body = data ["error" ],
178180 )
179181
180- yield process_data (data = data , cast_to = cast_to , response = response )
181-
182+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
182183 else :
183184 data = sse .json ()
184-
185- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
185+ if is_mapping (data ) and data .get ("error" ):
186186 message = None
187187 error = data .get ("error" )
188188 if is_mapping (error ):
@@ -196,7 +196,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
196196 body = data ["error" ],
197197 )
198198
199- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
199+ yield process_data (data = data , cast_to = cast_to , response = response )
200200
201201 # Ensure the entire stream is consumed
202202 async for _sse in iterator :
0 commit comments