Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/together/resources/audio/transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,12 @@ def create(
)

# Add any additional kwargs
params_data.update(kwargs)
# Convert boolean values to lowercase strings for proper form encoding
for key, value in kwargs.items():
if isinstance(value, bool):
params_data[key] = str(value).lower()
else:
params_data[key] = value

try:
response, _, _ = requestor.request(
Expand All @@ -131,7 +136,8 @@ def create(
response_format == "verbose_json"
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
):
return AudioTranscriptionVerboseResponse(**response.data)
# Create response with model validation that preserves extra fields
return AudioTranscriptionVerboseResponse.model_validate(response.data)
else:
return AudioTranscriptionResponse(**response.data)

Expand Down Expand Up @@ -234,7 +240,12 @@ async def create(
)

# Add any additional kwargs
params_data.update(kwargs)
# Convert boolean values to lowercase strings for proper form encoding
for key, value in kwargs.items():
if isinstance(value, bool):
params_data[key] = str(value).lower()
else:
params_data[key] = value

try:
response, _, _ = await requestor.arequest(
Expand All @@ -261,6 +272,7 @@ async def create(
response_format == "verbose_json"
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
):
return AudioTranscriptionVerboseResponse(**response.data)
# Create response with model validation that preserves extra fields
return AudioTranscriptionVerboseResponse.model_validate(response.data)
else:
return AudioTranscriptionResponse(**response.data)
13 changes: 13 additions & 0 deletions src/together/types/audio_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,31 @@ class AudioTranscriptionWord(BaseModel):
word: str
start: float
end: float
id: Optional[int] = None
speaker_id: Optional[str] = None


class AudioSpeakerSegment(BaseModel):
    """One speaker-attributed segment of a transcription.

    Every field is required: none of them declares a default value.
    """

    # Identifier of this segment (integer).
    id: int
    # Identifier of the speaker this segment is attributed to.
    speaker_id: str
    # Segment boundaries; presumably offsets in seconds -- TODO confirm units.
    start: float
    end: float
    # Text of the segment.
    text: str
    # Word-level entries belonging to this segment.
    words: List[AudioTranscriptionWord]


class AudioTranscriptionResponse(BaseModel):
    """Transcription response model that carries a single required ``text`` field."""

    # The transcription text (required; no default).
    text: str


class AudioTranscriptionVerboseResponse(BaseModel):
    """Verbose transcription response model.

    Only ``text`` is required; every other field defaults to ``None`` when it
    is not supplied.
    """

    # Optional identifier of the response.
    id: Optional[str] = None
    # Optional language value (string).
    language: Optional[str] = None
    # Optional duration; presumably in seconds -- TODO confirm units.
    duration: Optional[float] = None
    # The transcription text (required; no default).
    text: str
    # Optional list of transcription segments.
    segments: Optional[List[AudioTranscriptionSegment]] = None
    # Optional list of word-level entries.
    words: Optional[List[AudioTranscriptionWord]] = None
    # Optional list of speaker-attributed segments.
    speaker_segments: Optional[List[AudioSpeakerSegment]] = None


class AudioTranslationResponse(BaseModel):
Expand Down
134 changes: 134 additions & 0 deletions tests/integration/resources/test_transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,47 @@
)


# Keys every entry of ``speaker_segments`` must carry.
_SPEAKER_SEGMENT_KEYS = ("text", "id", "speaker_id", "start", "end", "words")
# Keys every word entry must carry (nested in a segment or top-level).
_DIARIZED_WORD_KEYS = ("id", "word", "start", "end", "speaker_id")


def _assert_has_keys(item, required_keys, label):
    """Assert that each key in *required_keys* is present in *item*."""
    for key in required_keys:
        assert key in item, f"{label} is missing required key {key!r}"


def validate_diarization_response(response_dict) -> None:
    """
    Helper function to validate diarization response structure

    Checks that ``response_dict`` has non-empty ``speaker_segments`` and
    ``words`` lists, that every speaker segment carries the segment keys and
    a ``words`` list, and that every word (nested or top-level) carries the
    word keys. The ``words`` list inside a segment may be empty.

    Args:
        response_dict: Mapping to validate, e.g. the ``model_dump()`` of a
            verbose transcription response.

    Raises:
        AssertionError: On the first structural check that fails; the message
            names the offending element and key.
    """
    # Validate top-level speaker_segments field
    assert "speaker_segments" in response_dict, "response is missing 'speaker_segments'"
    speaker_segments = response_dict["speaker_segments"]
    assert isinstance(speaker_segments, list), "'speaker_segments' must be a list"
    assert len(speaker_segments) > 0, "'speaker_segments' must not be empty"

    # Validate each speaker segment structure
    for seg_index, segment in enumerate(speaker_segments):
        seg_label = f"speaker_segments[{seg_index}]"
        _assert_has_keys(segment, _SPEAKER_SEGMENT_KEYS, seg_label)

        # Validate nested words in speaker segments
        nested_words = segment["words"]
        assert isinstance(nested_words, list), f"{seg_label}['words'] must be a list"
        for word_index, word in enumerate(nested_words):
            _assert_has_keys(
                word, _DIARIZED_WORD_KEYS, f"{seg_label}['words'][{word_index}]"
            )

    # Validate top-level words field
    assert "words" in response_dict, "response is missing 'words'"
    top_level_words = response_dict["words"]
    assert isinstance(top_level_words, list), "'words' must be a list"
    assert len(top_level_words) > 0, "'words' must not be empty"

    # Validate each word in top-level words
    for word_index, word in enumerate(top_level_words):
        _assert_has_keys(word, _DIARIZED_WORD_KEYS, f"words[{word_index}]")


class TestTogetherTranscriptions:
@pytest.fixture
def sync_together_client(self) -> Together:
Expand Down Expand Up @@ -116,3 +157,96 @@ def test_language_detection_hindi(self, sync_together_client):
assert len(response.text) > 0
assert hasattr(response, "language")
assert response.language == "hi"

def test_diarization_default(self, sync_together_client):
    """
    Verbose-JSON transcription with ``diarize=True`` and no explicit
    ``diarization_model`` must return a populated diarization structure.
    """
    sample_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
    request_kwargs = {
        "file": sample_url,
        "model": "openai/whisper-large-v3",
        "response_format": "verbose_json",
        "diarize": True,
    }

    result = sync_together_client.audio.transcriptions.create(**request_kwargs)

    # The verbose format must come back as the verbose model with text in it.
    assert isinstance(result, AudioTranscriptionVerboseResponse)
    transcript = result.text
    assert isinstance(transcript, str)
    assert len(transcript) > 0

    # Structural checks on speaker_segments / words happen on the dumped dict.
    validate_diarization_response(result.model_dump())

def test_diarization_nvidia(self, sync_together_client):
    """
    Verbose-JSON transcription with ``diarize=True`` and
    ``diarization_model="nvidia"`` must return a populated diarization
    structure.
    """
    sample_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
    request_kwargs = {
        "file": sample_url,
        "model": "openai/whisper-large-v3",
        "response_format": "verbose_json",
        "diarize": True,
        "diarization_model": "nvidia",
    }

    result = sync_together_client.audio.transcriptions.create(**request_kwargs)

    # The verbose format must come back as the verbose model with text in it.
    assert isinstance(result, AudioTranscriptionVerboseResponse)
    transcript = result.text
    assert isinstance(transcript, str)
    assert len(transcript) > 0

    # Structural checks on speaker_segments / words happen on the dumped dict.
    validate_diarization_response(result.model_dump())

def test_diarization_pyannote(self, sync_together_client):
    """
    Verbose-JSON transcription with ``diarize=True`` and
    ``diarization_model="pyannote"`` must return a populated diarization
    structure.
    """
    sample_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
    request_kwargs = {
        "file": sample_url,
        "model": "openai/whisper-large-v3",
        "response_format": "verbose_json",
        "diarize": True,
        "diarization_model": "pyannote",
    }

    result = sync_together_client.audio.transcriptions.create(**request_kwargs)

    # The verbose format must come back as the verbose model with text in it.
    assert isinstance(result, AudioTranscriptionVerboseResponse)
    transcript = result.text
    assert isinstance(transcript, str)
    assert len(transcript) > 0

    # Structural checks on speaker_segments / words happen on the dumped dict.
    validate_diarization_response(result.model_dump())

def test_no_diarization(self, sync_together_client):
    """
    With ``diarize=False`` the verbose response must not carry speaker
    segments or top-level words, while the standard fields stay present.
    """
    sample_url = "https://together-public-test-data.s3.us-west-2.amazonaws.com/audio/2-speaker-conversation.wav"
    request_kwargs = {
        "file": sample_url,
        "model": "openai/whisper-large-v3",
        "response_format": "verbose_json",
        "diarize": False,
    }

    result = sync_together_client.audio.transcriptions.create(**request_kwargs)

    assert isinstance(result, AudioTranscriptionVerboseResponse)
    transcript = result.text
    assert isinstance(transcript, str)
    assert len(transcript) > 0

    dumped = result.model_dump()

    # Diarization-only fields must be missing or None.
    for diarization_field in ("speaker_segments", "words"):
        assert dumped.get(diarization_field) is None

    # Standard fields must still be keys of the dump.
    # NOTE(review): model_dump() emits every declared field, so these
    # presence checks likely pass even when the value is None -- confirm
    # whether non-None values were intended.
    for standard_field in ("text", "language", "duration", "segments"):
        assert standard_field in dumped