Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions libs/aws/langchain_aws/embeddings/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ def _inferred_provider(self) -> str:
parts = self.model_id.split(".")
return parts[1] if parts[0] in regions else parts[0]

@property
def _is_cohere_v4(self) -> bool:
    """Report whether ``model_id`` refers to a Cohere Embed v4 model.

    Returns:
        ``True`` when the substring ``"cohere.embed-v4"`` occurs anywhere in
        ``self.model_id``; ``False`` otherwise.
    """
    v4_marker = "cohere.embed-v4"
    return v4_marker in self.model_id

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials and python package exist in environment."""
Expand Down Expand Up @@ -189,7 +194,12 @@ def _embedding_func(
embeddings = response_body.get("embeddings")
if embeddings is None:
raise ValueError("No embeddings returned from model")
return embeddings[0]
# Embed v3 and v4 schemas
if isinstance(embeddings, dict) and "float" in embeddings:
processed_embeddings = embeddings["float"]
else:
processed_embeddings = embeddings
return processed_embeddings[0]
else:
# includes common provider == "amazon"
response_body = self._invoke_model(
Expand All @@ -207,7 +217,9 @@ def _cohere_multi_embedding(self, texts: List[str]) -> List[List[float]]:
results: List[List[float]] = []

# Iterate through the list of strings in batches
for text_batch in _batch_cohere_embedding_texts(texts):
for text_batch in _batch_cohere_embedding_texts(
texts, is_v4=self._is_cohere_v4
):
batch_embeddings = self._invoke_model(
input_body={
"input_type": "search_document",
Expand Down Expand Up @@ -344,17 +356,35 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return list(result)


def _batch_cohere_embedding_texts(texts: List[str]) -> Generator[List[str], None, None]:
def _batch_cohere_embedding_texts(
texts: List[str], is_v4: bool = False
) -> Generator[List[str], None, None]:
"""Batches a set of texts into chunks acceptable for the Cohere embedding API.

Chunks of at most 96 items, or 2048 characters.
For Cohere Embed v3: Chunks of at most 96 items, or 2048 characters.
For Cohere Embed v4: Chunks of at most 96 items, or ~512,000 characters
(approx 128K tokens).

"""

# Cohere embeddings want a maximum of 96 items and 2048 characters
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
max_items = 96
max_chars = 2048
if is_v4:
# Cohere Embed v4 supports up to 128K tokens per input
# Using conservative estimate of ~4 chars per token = ~512K chars
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed-v4.html
max_chars = 512_000
char_limit_msg = (
"The Cohere Embed v4 embedding API does not support texts longer than "
"approximately 128K tokens (~512,000 characters)."
)
else:
# Cohere Embed v3 limit
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
max_chars = 2048
char_limit_msg = (
"The Cohere embedding API does not support texts longer than "
"2048 characters."
)

# Initialize batches
current_batch: List[str] = []
Expand All @@ -364,10 +394,7 @@ def _batch_cohere_embedding_texts(texts: List[str]) -> Generator[List[str], None
text_len = len(text)

if text_len > max_chars:
raise ValueError(
"The Cohere embedding API does not support texts longer than "
"2048 characters."
)
raise ValueError(char_limit_msg)

# Check if adding the current string would exceed the limits
if len(current_batch) >= max_items or current_chars + text_len > max_chars:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,16 @@ def test_bedrock_embedding_provider_arg(
assert cohere_embeddings_v3._inferred_provider == "cohere"
assert cohere_embeddings_v4._inferred_provider == "cohere"
assert cohere_embeddings_model_arn._inferred_provider == "cohere"


# @pytest.mark.scheduled
@pytest.mark.skip(reason="CI does not have access to v4 embeddings.")
def test_bedrock_cohere_v4_large_input(cohere_embeddings_v4) -> None:
    """Embed one 3000-character document with the Cohere v4 fixture.

    3000 characters is above the 2048-character limit that applies to v3, so
    the call only succeeds when the v4 limits are in effect. The single
    returned vector is expected to have 1536 components.
    """
    v3_char_limit_exceeded = 3000
    oversized_document = "x" * v3_char_limit_exceeded

    vectors = cohere_embeddings_v4.embed_documents([oversized_document])

    # Exactly one vector comes back for the one document, with the v4 width.
    assert len(vectors) == 1
    assert len(vectors[0]) == 1536
1 change: 1 addition & 0 deletions libs/aws/tests/unit_tests/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for embeddings."""
119 changes: 119 additions & 0 deletions libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Test Cohere v4 embedding fixes."""

from unittest.mock import Mock, patch

import pytest

from langchain_aws.embeddings.bedrock import (
BedrockEmbeddings,
_batch_cohere_embedding_texts,
)


class TestCohereV4Fixes:
    """Unit tests for the Cohere Embed v4 handling in ``BedrockEmbeddings``.

    Covers the ``_is_cohere_v4`` property, the per-version character limits of
    ``_batch_cohere_embedding_texts``, and parsing of both the bare-list and
    the ``{"float": [...]}`` shapes of the ``embeddings`` response field.
    """

    @staticmethod
    def _install_client(mock_create_client: Mock, payload: str) -> None:
        """Wire the patched client factory to a client replying with *payload*.

        The stub's ``invoke_model`` returns a dict whose ``"body"`` exposes a
        zero-argument ``read`` that yields the given JSON text.
        """
        client = Mock()
        mock_create_client.return_value = client
        client.invoke_model.return_value = {"body": Mock(read=lambda: payload)}

    def test_is_cohere_v4_property_v4_model(self) -> None:
        """``_is_cohere_v4`` is True for the ``us.cohere.embed-v4:0`` id."""
        v4_model = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")
        assert v4_model._is_cohere_v4 is True

    def test_is_cohere_v4_property_v3_model(self) -> None:
        """``_is_cohere_v4`` is False for a Cohere v3 model id."""
        v3_model = BedrockEmbeddings(model_id="cohere.embed-english-v3")
        assert v3_model._is_cohere_v4 is False

    def test_is_cohere_v4_property_non_cohere_model(self) -> None:
        """``_is_cohere_v4`` is False for an Amazon Titan model id."""
        titan_model = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
        assert titan_model._is_cohere_v4 is False

    def test_batch_cohere_v3_limits(self) -> None:
        """v3 batching groups short texts and rejects a 2049-character text."""
        # Ten short texts must end up together in a single batch.
        batch_sizes = [
            len(batch)
            for batch in _batch_cohere_embedding_texts(["hello"] * 10, is_v4=False)
        ]
        assert batch_sizes == [10]

        # One character past the 2048 limit must be refused.
        with pytest.raises(ValueError, match="2048 characters"):
            list(_batch_cohere_embedding_texts(["x" * 2049], is_v4=False))

    def test_batch_cohere_v4_limits(self) -> None:
        """v4 batching accepts 10000 characters and rejects 600000."""
        # Larger than the v3 limit of 2048, far below the v4 limit of 512k.
        accepted_text = "x" * 10000
        batch_sizes = [
            len(batch)
            for batch in _batch_cohere_embedding_texts([accepted_text], is_v4=True)
        ]
        assert batch_sizes == [1]

        # Beyond the 512k-character v4 limit.
        rejected_text = "x" * 600000
        with pytest.raises(ValueError, match="128K tokens"):
            list(_batch_cohere_embedding_texts([rejected_text], is_v4=True))

    @patch("langchain_aws.embeddings.bedrock.create_aws_client")
    def test_embedding_func_cohere_v3_schema(self, mock_create_client: Mock) -> None:
        """``_embedding_func`` reads a v3 response (``embeddings`` is a list)."""
        self._install_client(mock_create_client, '{"embeddings": [[0.1, 0.2, 0.3]]}')

        v3_model = BedrockEmbeddings(model_id="cohere.embed-english-v3")

        assert v3_model._embedding_func("test text") == [0.1, 0.2, 0.3]

    @patch("langchain_aws.embeddings.bedrock.create_aws_client")
    def test_embedding_func_cohere_v4_schema(self, mock_create_client: Mock) -> None:
        """``_embedding_func`` reads a v4 response (``embeddings`` is a dict)."""
        self._install_client(
            mock_create_client, '{"embeddings": {"float": [[0.1, 0.2, 0.3]]}}'
        )

        v4_model = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")

        assert v4_model._embedding_func("test text") == [0.1, 0.2, 0.3]

    @patch("langchain_aws.embeddings.bedrock.create_aws_client")
    def test_cohere_multi_embedding_uses_v4_batching(
        self, mock_create_client: Mock
    ) -> None:
        """``_cohere_multi_embedding`` forwards ``is_v4=True`` to the batcher."""
        self._install_client(
            mock_create_client,
            '{"embeddings": {"float": [[0.1, 0.2], [0.3, 0.4]]}}',
        )
        v4_model = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")

        # Each text exceeds the v3 character limit but is well within v4's.
        documents = ["x" * 3000, "y" * 3000]

        # The batcher is replaced so that it hands back everything as one batch.
        with patch(
            "langchain_aws.embeddings.bedrock._batch_cohere_embedding_texts",
            return_value=[documents],
        ) as mock_batch:
            vectors = v4_model._cohere_multi_embedding(documents)

        # The recorded call must carry the v4 flag.
        mock_batch.assert_called_once_with(documents, is_v4=True)
        assert len(vectors) == 2