|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 | 15 | import copy
|
| 16 | +import math |
16 | 17 | import unittest
|
17 | 18 | import unittest.mock as mock
|
18 | 19 |
|
def batch_embed_text(
    request: glm.BatchEmbedTextRequest,
) -> glm.BatchEmbedTextResponse:
    """Mock batch-embedding RPC: records the request, returns a canned response.

    Fix: the annotations previously named the non-batch types
    (``glm.EmbedTextRequest`` / ``glm.EmbedTextResponse``), but the body
    reads ``request.texts`` — a field of the *batch* request — and builds a
    ``glm.BatchEmbedTextResponse``, so the batch variants are the correct
    annotations.

    NOTE(review): ``self`` is not a parameter here — it appears to be
    captured from an enclosing test-setup scope; confirm against the
    surrounding file.
    """
    # Record the outgoing request so the tests can assert on what was sent.
    self.observed_requests.append(request)
    # Return one dummy embedding per input text, so callers can check that
    # the response length matches the request length.
    return glm.BatchEmbedTextResponse(
        embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts)
    )
65 | 69 |
|
66 | 70 | @add_client_method
|
67 | 71 | def count_text_tokens(
|
@@ -120,27 +124,46 @@ def test_generate_embeddings(self, model, text):
|
@parameterized.named_parameters(
    [
        dict(
            testcase_name="small-2",
            model="models/chat-lamda-001",
            text=["Who are you?", "Who am I?"],
        ),
        dict(
            testcase_name="even-batch",
            model="models/chat-lamda-001",
            text=["Who are you?"] * 100,
        ),
        dict(
            testcase_name="even-batch-plus-one",
            model="models/chat-lamda-001",
            text=["Who are you?"] * 101,
        ),
        dict(
            testcase_name="odd-batch",
            model="models/chat-lamda-001",
            text=["Who are you?"] * 237,
        ),
    ]
)
def test_generate_embeddings_batch(self, model, text):
    """Batched embedding calls split the inputs into correctly sized requests."""
    result = text_service.generate_embeddings(model=model, text=text)

    self.assertIsInstance(result, dict)

    # Spot-check the boundary requests: the final request carries the model
    # and ends with the last input text; the first request starts with the
    # first input text.
    first_request = self.observed_requests[0]
    last_request = self.observed_requests[-1]
    self.assertEqual(last_request.model, model)
    self.assertEqual(last_request.texts[-1], text[-1])
    self.assertEqual(first_request.texts[0], text[0])

    # Every input text produced exactly one embedding (a plain list).
    self.assertIsInstance(result["embedding"][0], list)
    self.assertLen(result["embedding"], len(text))

    # The inputs were fanned out over ceil(n / batch_size) requests.
    expected_request_count = math.ceil(
        len(text) / text_service.EMBEDDING_MAX_BATCH_SIZE
    )
    self.assertLen(self.observed_requests, expected_request_count)
144 | 167 |
|
145 | 168 | @parameterized.named_parameters(
|
146 | 169 | [
|
|
0 commit comments