From 7395868b9802f8c7807b66e686b99a5267d5eb12 Mon Sep 17 00:00:00 2001 From: Garima Dhanania Date: Mon, 8 Sep 2025 14:58:18 -0700 Subject: [PATCH 1/4] added speaker diarization tests --- .../resources/audio/transcriptions.py | 6 +- src/together/types/audio_speech.py | 14 ++ .../resources/test_transcriptions.py | 139 ++++++++++++++++++ 3 files changed, 157 insertions(+), 2 deletions(-) diff --git a/src/together/resources/audio/transcriptions.py b/src/together/resources/audio/transcriptions.py index 766d417..49b83dd 100644 --- a/src/together/resources/audio/transcriptions.py +++ b/src/together/resources/audio/transcriptions.py @@ -131,7 +131,8 @@ def create( response_format == "verbose_json" or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON ): - return AudioTranscriptionVerboseResponse(**response.data) + # Create response with model validation that preserves extra fields + return AudioTranscriptionVerboseResponse.model_validate(response.data) else: return AudioTranscriptionResponse(**response.data) @@ -261,6 +262,7 @@ async def create( response_format == "verbose_json" or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON ): - return AudioTranscriptionVerboseResponse(**response.data) + # Create response with model validation that preserves extra fields + return AudioTranscriptionVerboseResponse.model_validate(response.data) else: return AudioTranscriptionResponse(**response.data) diff --git a/src/together/types/audio_speech.py b/src/together/types/audio_speech.py index b3c110f..09e7ddc 100644 --- a/src/together/types/audio_speech.py +++ b/src/together/types/audio_speech.py @@ -158,6 +158,17 @@ class AudioTranscriptionWord(BaseModel): word: str start: float end: float + id: Optional[int] = None + speaker_id: Optional[str] = None + + +class AudioSpeakerSegment(BaseModel): + id: int + speaker_id: str + start: float + end: float + text: str + words: List[AudioTranscriptionWord] class AudioTranscriptionResponse(BaseModel): @@ -165,11 +176,14 @@ class AudioTranscriptionResponse(BaseModel): class AudioTranscriptionVerboseResponse(BaseModel): + model_config = ConfigDict(extra="allow") + language: Optional[str] = None duration: Optional[float] = None text: str segments: Optional[List[AudioTranscriptionSegment]] = None words: Optional[List[AudioTranscriptionWord]] = None + speaker_segments: Optional[List[AudioSpeakerSegment]] = None class AudioTranslationResponse(BaseModel): diff --git a/tests/integration/resources/test_transcriptions.py b/tests/integration/resources/test_transcriptions.py index 3852ebe..fbc5a91 100644 --- a/tests/integration/resources/test_transcriptions.py +++ b/tests/integration/resources/test_transcriptions.py @@ -9,6 +9,52 @@ ) +def validate_diarization_response(response_dict): + """ + Helper function to validate diarization response structure + """ + # Validate top-level speaker_segments field + assert "speaker_segments" in response_dict + assert isinstance(response_dict["speaker_segments"], list) + assert len(response_dict["speaker_segments"]) > 0 + + # Validate each speaker segment structure + for segment in response_dict["speaker_segments"]: + assert "text" in segment + assert "id" in segment + assert "speaker_id" in segment + assert "start" in segment + assert "end" in segment + assert "words" in segment + + # Validate nested words in speaker segments + assert isinstance(segment["words"], list) + for word in segment["words"]: + assert "id" in word + assert "word" in word + assert "start" in word + assert "end" in word + assert "speaker_id" in 
word + + # Note: The top-level words field should be present in the API response but + # may not be preserved by the SDK currently. We check for it but don't fail + # the test if it's missing, as the speaker_segments contain all the word data. + if "words" in response_dict and response_dict["words"] is not None: + assert isinstance(response_dict["words"], list) + assert len(response_dict["words"]) > 0 + + # Validate each word in top-level words + for word in response_dict["words"]: + assert "id" in word + assert "word" in word + assert "start" in word + assert "end" in word + assert "speaker_id" in word + else: + # Log that words field is missing (expected with current SDK) + print("Note: Top-level 'words' field not preserved by SDK (known issue)") + + class TestTogetherTranscriptions: @pytest.fixture def sync_together_client(self) -> Together: @@ -116,3 +162,96 @@ def test_language_detection_hindi(self, sync_together_client): assert len(response.text) > 0 assert hasattr(response, "language") assert response.language == "hi" + + def test_diarization_default(self, sync_together_client): + """ + Test diarization with default model in verbose JSON format + """ + audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav" + + response = sync_together_client.audio.transcriptions.create( + file=audio_url, + model="openai/whisper-large-v3-test", + response_format="verbose_json", + diarize=True, + ) + + assert isinstance(response, AudioTranscriptionVerboseResponse) + assert isinstance(response.text, str) + assert len(response.text) > 0 + + # Validate diarization fields + response_dict = response.model_dump() + validate_diarization_response(response_dict) + + def test_diarization_nvidia(self, sync_together_client): + """ + Test diarization with nvidia model in verbose JSON format + """ + audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav" + + response = sync_together_client.audio.transcriptions.create( + file=audio_url, + model="openai/whisper-large-v3-test", + response_format="verbose_json", + diarize=True, + diarization_model="nvidia", + ) + + assert isinstance(response, AudioTranscriptionVerboseResponse) + assert isinstance(response.text, str) + assert len(response.text) > 0 + + # Validate diarization fields + response_dict = response.model_dump() + validate_diarization_response(response_dict) + + def test_diarization_pyannote(self, sync_together_client): + """ + Test diarization with pyannote model in verbose JSON format + """ + audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav" + + response = sync_together_client.audio.transcriptions.create( + file=audio_url, + model="openai/whisper-large-v3-test", + response_format="verbose_json", + diarize=True, + diarization_model="pyannote", + ) + + assert isinstance(response, AudioTranscriptionVerboseResponse) + assert isinstance(response.text, str) + assert len(response.text) > 0 + + # Validate diarization fields + response_dict = response.model_dump() + validate_diarization_response(response_dict) + + def test_no_diarization(self, sync_together_client): + """ + Test with diarize=false should not have speaker segments + """ + audio_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav" + + response = sync_together_client.audio.transcriptions.create( + file=audio_url, + model="openai/whisper-large-v3-test", + response_format="verbose_json", + 
diarize=False, + ) + + assert isinstance(response, AudioTranscriptionVerboseResponse) + assert isinstance(response.text, str) + assert len(response.text) > 0 + + # Verify no diarization fields + response_dict = response.model_dump() + assert response_dict.get('speaker_segments') is None + assert response_dict.get('words') is None + + # Should still have standard fields + assert 'text' in response_dict + assert 'language' in response_dict + assert 'duration' in response_dict + assert 'segments' in response_dict From 53d2a3c195eec353b18ec18d42b8ab98593d2158 Mon Sep 17 00:00:00 2001 From: Garima Dhanania Date: Mon, 8 Sep 2025 17:19:04 -0700 Subject: [PATCH 2/4] changed to main model --- tests/integration/resources/test_transcriptions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/resources/test_transcriptions.py b/tests/integration/resources/test_transcriptions.py index fbc5a91..8e4901a 100644 --- a/tests/integration/resources/test_transcriptions.py +++ b/tests/integration/resources/test_transcriptions.py @@ -171,7 +171,7 @@ def test_diarization_default(self, sync_together_client): response = sync_together_client.audio.transcriptions.create( file=audio_url, - model="openai/whisper-large-v3-test", + model="openai/whisper-large-v3", response_format="verbose_json", diarize=True, ) @@ -192,7 +192,7 @@ def test_diarization_nvidia(self, sync_together_client): response = sync_together_client.audio.transcriptions.create( file=audio_url, - model="openai/whisper-large-v3-test", + model="openai/whisper-large-v3", response_format="verbose_json", diarize=True, diarization_model="nvidia", @@ -214,7 +214,7 @@ def test_diarization_pyannote(self, sync_together_client): response = sync_together_client.audio.transcriptions.create( file=audio_url, - model="openai/whisper-large-v3-test", + model="openai/whisper-large-v3", response_format="verbose_json", diarize=True, diarization_model="pyannote", @@ -236,7 +236,7 @@ def test_no_diarization(self, sync_together_client): response = sync_together_client.audio.transcriptions.create( file=audio_url, - model="openai/whisper-large-v3-test", + model="openai/whisper-large-v3", response_format="verbose_json", diarize=False, ) From 96363f9d220e45533625a0301b25c0dae11746c8 Mon Sep 17 00:00:00 2001 From: Garima Dhanania Date: Mon, 15 Sep 2025 13:09:30 -0700 Subject: [PATCH 3/4] fixed top-level words --- .../resources/audio/transcriptions.py | 14 +++++++-- src/together/types/audio_speech.py | 3 +- .../resources/test_transcriptions.py | 29 ++++++++----------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/together/resources/audio/transcriptions.py b/src/together/resources/audio/transcriptions.py index 49b83dd..49aea2a 100644 --- a/src/together/resources/audio/transcriptions.py +++ b/src/together/resources/audio/transcriptions.py @@ -104,7 +104,12 @@ def create( ) # Add any additional kwargs - params_data.update(kwargs) + # Convert boolean values to lowercase strings for proper form encoding + for key, value in kwargs.items(): + if isinstance(value, bool): + params_data[key] = str(value).lower() + else: + params_data[key] = value try: response, _, _ = requestor.request( @@ -235,7 +240,12 @@ async def create( ) # Add any additional kwargs - params_data.update(kwargs) + # Convert boolean values to lowercase strings for proper form encoding + for key, value in kwargs.items(): + if isinstance(value, bool): + params_data[key] = str(value).lower() + else: + params_data[key] = value try: response, _, _ = await 
requestor.arequest( diff --git a/src/together/types/audio_speech.py b/src/together/types/audio_speech.py index 09e7ddc..bb54cc7 100644 --- a/src/together/types/audio_speech.py +++ b/src/together/types/audio_speech.py @@ -176,8 +176,7 @@ class AudioTranscriptionResponse(BaseModel): class AudioTranscriptionVerboseResponse(BaseModel): - model_config = ConfigDict(extra="allow") - + id: Optional[str] = None language: Optional[str] = None duration: Optional[float] = None text: str diff --git a/tests/integration/resources/test_transcriptions.py b/tests/integration/resources/test_transcriptions.py index 8e4901a..37be253 100644 --- a/tests/integration/resources/test_transcriptions.py +++ b/tests/integration/resources/test_transcriptions.py @@ -36,23 +36,18 @@ def validate_diarization_response(response_dict): assert "end" in word assert "speaker_id" in word - # Note: The top-level words field should be present in the API response but - # may not be preserved by the SDK currently. We check for it but don't fail - # the test if it's missing, as the speaker_segments contain all the word data. - if "words" in response_dict and response_dict["words"] is not None: - assert isinstance(response_dict["words"], list) - assert len(response_dict["words"]) > 0 - - # Validate each word in top-level words - for word in response_dict["words"]: - assert "id" in word - assert "word" in word - assert "start" in word - assert "end" in word - assert "speaker_id" in word - else: - # Log that words field is missing (expected with current SDK) - print("Note: Top-level 'words' field not preserved by SDK (known issue)") + # Validate top-level words field + assert "words" in response_dict + assert isinstance(response_dict["words"], list) + assert len(response_dict["words"]) > 0 + + # Validate each word in top-level words + for word in response_dict["words"]: + assert "id" in word + assert "word" in word + assert "start" in word + assert "end" in word + assert "speaker_id" in word class TestTogetherTranscriptions: From a9a49164a07bd85bddbb8143b76f2a11ebce9911 Mon Sep 17 00:00:00 2001 From: Garima Dhanania Date: Mon, 15 Sep 2025 13:17:38 -0700 Subject: [PATCH 4/4] formatting --- tests/integration/resources/test_transcriptions.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/integration/resources/test_transcriptions.py b/tests/integration/resources/test_transcriptions.py index 37be253..d8afd36 100644 --- a/tests/integration/resources/test_transcriptions.py +++ b/tests/integration/resources/test_transcriptions.py @@ -242,11 +242,11 @@ def test_no_diarization(self, sync_together_client): # Verify no diarization fields response_dict = response.model_dump() - assert response_dict.get('speaker_segments') is None - assert response_dict.get('words') is None - + assert response_dict.get("speaker_segments") is None + assert response_dict.get("words") is None + # Should still have standard fields - assert 'text' in response_dict - assert 'language' in response_dict - assert 'duration' in response_dict - assert 'segments' in response_dict + assert "text" in response_dict + assert "language" in response_dict + assert "duration" in response_dict + assert "segments" in response_dict
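
A minimal usage sketch of the diarization support these patches exercise, assuming a valid TOGETHER_API_KEY in the environment; the audio URL and model name are taken from the test fixtures above, and "pyannote"/"nvidia" are the diarization_model values the tests pass. This is an illustration of the tested call shape, not part of the patches themselves.

    from together import Together
    from together.types.audio_speech import AudioTranscriptionVerboseResponse

    client = Together()  # assumes TOGETHER_API_KEY is set in the environment

    response = client.audio.transcriptions.create(
        file="https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav",
        model="openai/whisper-large-v3",
        response_format="verbose_json",
        diarize=True,                  # booleans are form-encoded as "true"/"false" per the change in PATCH 3/4
        diarization_model="pyannote",  # optional; the tests also exercise "nvidia"
    )

    # verbose_json responses come back as AudioTranscriptionVerboseResponse, which
    # now carries speaker_segments plus per-word speaker_id/id fields.
    assert isinstance(response, AudioTranscriptionVerboseResponse)
    for segment in response.speaker_segments or []:
        print(f"[{segment.speaker_id}] {segment.start:.2f}-{segment.end:.2f}: {segment.text}")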