
Commit d4fa774

MarkDaoust authored and markmcd committed
Handle max batch size for embeddings. (google-gemini#83)
1 parent 1285c29 commit d4fa774

2 files changed: +67 -19 lines


google/generativeai/text.py

Lines changed: 31 additions & 6 deletions
@@ -15,8 +15,9 @@
 from __future__ import annotations

 import dataclasses
-from collections.abc import Sequence
-from typing import Iterable, overload
+from collections.abc import Iterable, Sequence
+import itertools
+from typing import Iterable, overload, TypeVar

 import google.ai.generativelanguage as glm

@@ -28,6 +29,26 @@
 from google.generativeai.types import safety_types

 DEFAULT_TEXT_MODEL = "models/text-bison-001"
+EMBEDDING_MAX_BATCH_SIZE = 100
+
+try:
+    # python 3.12+
+    _batched = itertools.batched  # type: ignore
+except AttributeError:
+    T = TypeVar("T")
+
+    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
+        if n < 1:
+            raise ValueError(f"Batch size `n` must be >1, got: {n}")
+        batch = []
+        for item in iterable:
+            batch.append(item)
+            if len(batch) == n:
+                yield batch
+                batch = []
+
+        if batch:
+            yield batch


 def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
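
Editor's note: a minimal, standalone sketch of the chunking behavior this fallback provides on Python versions before 3.12. It is not part of the commit; the `sample` list and chunk size are made up for illustration.

    # Illustrative only: mirrors the fallback _batched generator added above.
    from collections.abc import Iterable
    from typing import TypeVar

    T = TypeVar("T")

    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
        # Accumulate items and yield lists of length n; the final list may be shorter.
        batch: list[T] = []
        for item in iterable:
            batch.append(item)
            if len(batch) == n:
                yield batch
                batch = []
        if batch:
            yield batch

    sample = list(range(7))  # hypothetical input
    print([len(chunk) for chunk in _batched(sample, 3)])  # -> [3, 3, 1]
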
@@ -282,9 +303,13 @@ def generate_embeddings(
         embedding_dict = type(embedding_response).to_dict(embedding_response)
         embedding_dict["embedding"] = embedding_dict["embedding"]["value"]
     else:
-        embedding_request = glm.BatchEmbedTextRequest(model=model, texts=text)
-        embedding_response = client.batch_embed_text(embedding_request)
-        embedding_dict = type(embedding_response).to_dict(embedding_response)
-        embedding_dict["embedding"] = [e["value"] for e in embedding_dict["embeddings"]]
+        result = {"embedding": []}
+        for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE):
+            # TODO(markdaoust): This could use an option for returning an iterator or wait-bar.
+            embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch)
+            embedding_response = client.batch_embed_text(embedding_request)
+            embedding_dict = type(embedding_response).to_dict(embedding_response)
+            result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"])
+        return result

     return embedding_dict
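
Editor's note: for context, a minimal runnable sketch (not from the commit) of how the new else-branch merges per-batch responses into a single embedding list. `fake_batch_embed` and `embed_all` are hypothetical stand-ins for `client.batch_embed_text` and the surrounding function.

    # Illustrative only: the client call is faked; the constant mirrors text.py.
    import math

    EMBEDDING_MAX_BATCH_SIZE = 100  # same limit as the constant added above

    def fake_batch_embed(texts):
        # Hypothetical stand-in for client.batch_embed_text: one 3-d vector per input text.
        return {"embeddings": [{"value": [0.0, 0.0, 0.0]} for _ in texts]}

    def embed_all(texts):
        result = {"embedding": []}
        for start in range(0, len(texts), EMBEDDING_MAX_BATCH_SIZE):
            batch = texts[start : start + EMBEDDING_MAX_BATCH_SIZE]
            response = fake_batch_embed(batch)
            result["embedding"].extend(e["value"] for e in response["embeddings"])
        return result

    texts = ["Who are you?"] * 237
    out = embed_all(texts)
    assert len(out["embedding"]) == len(texts)  # one embedding per input
    assert math.ceil(len(texts) / EMBEDDING_MAX_BATCH_SIZE) == 3  # three requests sent
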

tests/test_text.py

Lines changed: 36 additions & 13 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import math
 import unittest
 import unittest.mock as mock

@@ -61,7 +62,10 @@ def batch_embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
             self.observed_requests.append(request)
-            return self.responses["batch_embed_text"]
+
+            return glm.BatchEmbedTextResponse(
+                embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts)
+            )

         @add_client_method
         def count_text_tokens(
@@ -120,27 +124,46 @@ def test_generate_embeddings(self, model, text):
     @parameterized.named_parameters(
         [
             dict(
-                testcase_name="basic_model",
+                testcase_name="small-2",
                 model="models/chat-lamda-001",
                 text=["Who are you?", "Who am I?"],
-            )
+            ),
+            dict(
+                testcase_name="even-batch",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 100,
+            ),
+            dict(
+                testcase_name="even-batch-plus-one",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 101,
+            ),
+            dict(
+                testcase_name="odd-batch",
+                model="models/chat-lamda-001",
+                text=["Who are you?"] * 237,
+            ),
         ]
     )
     def test_generate_embeddings_batch(self, model, text):
-        self.responses["batch_embed_text"] = glm.BatchEmbedTextResponse(
-            embeddings=[
-                glm.Embedding(value=[1, 2, 3]),
-                glm.Embedding(value=[4, 5, 6]),
-            ]
-        )
-
         emb = text_service.generate_embeddings(model=model, text=text)

         self.assertIsInstance(emb, dict)
-        self.assertEqual(
-            self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
-        )
+
+        # Check first and last requests.
+        self.assertEqual(self.observed_requests[-1].model, model)
+        self.assertEqual(self.observed_requests[-1].texts[-1], text[-1])
+        self.assertEqual(self.observed_requests[0].texts[0], text[0])
+
+        # Check that the list has the right length.
         self.assertIsInstance(emb["embedding"][0], list)
+        self.assertLen(emb["embedding"], len(text))
+
+        # Check that the right number of requests were sent.
+        self.assertLen(
+            self.observed_requests,
+            math.ceil(len(text) / text_service.EMBEDDING_MAX_BATCH_SIZE),
+        )

     @parameterized.named_parameters(
         [
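
Editor's note: a quick sanity check, not part of the test file, of the request counts the new parameterized cases should produce given the 100-item batch limit.

    # Illustrative only: expected request counts for the new test case sizes.
    import math

    EMBEDDING_MAX_BATCH_SIZE = 100

    for n_texts, expected_requests in [(2, 1), (100, 1), (101, 2), (237, 3)]:
        assert math.ceil(n_texts / EMBEDDING_MAX_BATCH_SIZE) == expected_requests
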
