Skip to content

Commit 964dccc

Browse files
authored
[AI] Chat history must store all the parts (streaming) (#7562)
The chat history in streaming mode reconstructs the parts from their contents, rather than storing the parts themselves. This causes non-visible elements, like `thoughtSignature` to get lost.
1 parent 9cc0669 commit 964dccc

File tree

4 files changed

+129
-13
lines changed

4 files changed

+129
-13
lines changed

firebase-ai/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Unreleased
22

3+
- [fixed] Fixed an issue causing streaming chat interactions to drop thought signatures. (#7562)
34
- [feature] Added support for server templates via `TemplateGenerativeModel` and
45
`TemplateImagenModel`. (#7503)
56

firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import com.google.firebase.ai.type.GenerateContentResponse
2222
import com.google.firebase.ai.type.ImagePart
2323
import com.google.firebase.ai.type.InlineDataPart
2424
import com.google.firebase.ai.type.InvalidStateException
25+
import com.google.firebase.ai.type.Part
2526
import com.google.firebase.ai.type.TextPart
2627
import com.google.firebase.ai.type.content
2728
import java.util.LinkedList
@@ -133,6 +134,7 @@ public class Chat(
133134
val bitmaps = LinkedList<Bitmap>()
134135
val inlineDataParts = LinkedList<InlineDataPart>()
135136
val text = StringBuilder()
137+
val parts = mutableListOf<Part>()
136138

137139
/**
138140
* TODO: revisit when images and inline data are returned. This will cause issues with how
@@ -147,22 +149,17 @@ public class Chat(
147149
is ImagePart -> bitmaps.add(part.image)
148150
is InlineDataPart -> inlineDataParts.add(part)
149151
}
152+
parts.add(part)
150153
}
151154
}
152155
.onCompletion {
153156
lock.release()
154157
if (it == null) {
155158
val content =
156159
content("model") {
157-
for (bitmap in bitmaps) {
158-
image(bitmap)
159-
}
160-
for (inlineDataPart in inlineDataParts) {
161-
inlineData(inlineDataPart.inlineData, inlineDataPart.mimeType)
162-
}
163-
if (text.isNotBlank()) {
164-
text(text.toString())
165-
}
160+
setParts(
161+
parts.filterNot { part -> part is TextPart && !part.hasContent() }.toMutableList()
162+
)
166163
}
167164

168165
history.add(prompt)
@@ -224,3 +221,12 @@ public class Chat(
224221
}
225222
}
226223
}
224+
225+
/**
226+
* Returns true if the [TextPart] contains any content, either in its [TextPart.text] property or
227+
* its [TextPart.thoughtSignature] property.
228+
*/
229+
private fun TextPart.hasContent(): Boolean {
230+
if (text.isNotEmpty()) return true
231+
return !thoughtSignature.isNullOrBlank()
232+
}

firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,29 @@ package com.google.firebase.ai
1818

1919
import com.google.firebase.ai.type.BlockReason
2020
import com.google.firebase.ai.type.FinishReason
21+
import com.google.firebase.ai.type.FunctionCallPart
2122
import com.google.firebase.ai.type.PromptBlockedException
2223
import com.google.firebase.ai.type.ResponseStoppedException
2324
import com.google.firebase.ai.type.ServerException
25+
import com.google.firebase.ai.type.content
2426
import com.google.firebase.ai.util.goldenDevAPIStreamingFile
2527
import io.kotest.assertions.throwables.shouldThrow
2628
import io.kotest.matchers.collections.shouldBeEmpty
29+
import io.kotest.matchers.collections.shouldNotBeEmpty
30+
import io.kotest.matchers.nulls.shouldNotBeNull
2731
import io.kotest.matchers.shouldBe
32+
import io.kotest.matchers.string.shouldStartWith
33+
import io.ktor.client.engine.mock.toByteArray
34+
import io.ktor.client.request.HttpRequestData
2835
import io.ktor.http.HttpStatusCode
2936
import kotlin.time.Duration.Companion.seconds
3037
import kotlinx.coroutines.flow.collect
3138
import kotlinx.coroutines.flow.toList
3239
import kotlinx.coroutines.withTimeout
40+
import kotlinx.serialization.json.Json
41+
import kotlinx.serialization.json.jsonArray
42+
import kotlinx.serialization.json.jsonObject
43+
import kotlinx.serialization.json.jsonPrimitive
3344
import org.junit.Test
3445
import org.junit.runner.RunWith
3546
import org.robolectric.RobolectricTestRunner
@@ -85,6 +96,86 @@ internal class DevAPIStreamingSnapshotTests {
8596
}
8697
}
8798

99+
@Test
100+
fun `success call with thought summary and signature`() =
101+
goldenDevAPIStreamingFile(
102+
"streaming-success-thinking-function-call-thought-summary-signature.txt"
103+
) {
104+
val responses = model.generateContentStream("prompt")
105+
106+
withTimeout(testTimeout) {
107+
val responseList = responses.toList()
108+
responseList.isEmpty() shouldBe false
109+
val functionCallResponse = responseList.find { it.functionCalls.isNotEmpty() }
110+
functionCallResponse.shouldNotBeNull()
111+
functionCallResponse.functionCalls.first().let {
112+
it.thoughtSignature.shouldNotBeNull()
113+
it.thoughtSignature.shouldStartWith("CiIBVKhc7vB")
114+
}
115+
}
116+
}
117+
118+
@Test
119+
fun `chat call with history including thought summary and signature`() {
120+
var capturedRequest: HttpRequestData? = null
121+
goldenDevAPIStreamingFile(
122+
"streaming-success-thinking-function-call-thought-summary-signature.txt",
123+
requestHandler = { capturedRequest = it }
124+
) {
125+
val chat = model.startChat()
126+
val firstPrompt = content { text("first prompt") }
127+
val secondPrompt = content { text("second prompt") }
128+
val responses = chat.sendMessageStream(firstPrompt)
129+
130+
withTimeout(testTimeout) {
131+
val responseList = responses.toList()
132+
responseList.shouldNotBeEmpty()
133+
134+
chat.history.let { history ->
135+
history.contains(firstPrompt)
136+
val functionCallPart =
137+
history.flatMap { it.parts }.first { it is FunctionCallPart } as FunctionCallPart
138+
functionCallPart.let {
139+
it.thoughtSignature.shouldNotBeNull()
140+
it.thoughtSignature.shouldStartWith("CiIBVKhc7vB")
141+
}
142+
}
143+
144+
// Reset the request so we can be sure we capture the latest version
145+
capturedRequest = null
146+
147+
// We don't care about the response, only the request
148+
val unused = chat.sendMessageStream(secondPrompt).toList()
149+
150+
// Make sure the history contains all prompts seen so far
151+
chat.history.contains(firstPrompt)
152+
chat.history.contains(secondPrompt)
153+
154+
// Put the captured request into a `val` to enable smart casting
155+
val request = capturedRequest
156+
request.shouldNotBeNull()
157+
val bodyAsString = request.body.toByteArray().decodeToString()
158+
bodyAsString.shouldNotBeNull()
159+
160+
val rootElement = Json.parseToJsonElement(bodyAsString).jsonObject
161+
162+
// Traverse the tree: contents -> parts -> thoughtSignature
163+
val contents = rootElement["contents"]?.jsonArray
164+
165+
val signature =
166+
contents?.firstNotNullOfOrNull { content ->
167+
content.jsonObject["parts"]?.jsonArray?.firstNotNullOfOrNull { part ->
168+
// resulting value is a JsonPrimitive, so we use .content to get the string
169+
part.jsonObject["thoughtSignature"]?.jsonPrimitive?.content
170+
}
171+
}
172+
173+
signature.shouldNotBeNull()
174+
signature.shouldStartWith("CiIBVKhc7vB")
175+
}
176+
}
177+
}
178+
88179
@Test
89180
fun `prompt blocked for safety`() =
90181
goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") {

firebase-ai/src/test/java/com/google/firebase/ai/util/tests.kt

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ import io.kotest.matchers.collections.shouldNotBeEmpty
2929
import io.kotest.matchers.nulls.shouldNotBeNull
3030
import io.ktor.client.engine.mock.MockEngine
3131
import io.ktor.client.engine.mock.respond
32+
import io.ktor.client.request.HttpRequestData
3233
import io.ktor.http.HttpHeaders
3334
import io.ktor.http.HttpStatusCode
3435
import io.ktor.http.headersOf
3536
import io.ktor.utils.io.ByteChannel
36-
import io.ktor.utils.io.close
3737
import io.ktor.utils.io.writeFully
3838
import java.io.File
3939
import kotlinx.coroutines.launch
@@ -103,6 +103,7 @@ internal fun commonTest(
103103
status: HttpStatusCode = HttpStatusCode.OK,
104104
requestOptions: RequestOptions = RequestOptions(),
105105
backend: GenerativeBackend = GenerativeBackend.vertexAI(),
106+
requestHandler: (HttpRequestData) -> Unit = {},
106107
block: CommonTest,
107108
) = doBlocking {
108109
val channel = ByteChannel(autoFlush = true)
@@ -115,6 +116,7 @@ internal fun commonTest(
115116
"gemini-pro",
116117
requestOptions,
117118
MockEngine {
119+
requestHandler(it)
118120
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
119121
},
120122
TEST_CLIENT_ID,
@@ -144,12 +146,13 @@ internal fun goldenStreamingFile(
144146
name: String,
145147
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
146148
backend: GenerativeBackend = GenerativeBackend.vertexAI(),
149+
requestHandler: (HttpRequestData) -> Unit,
147150
block: CommonTest,
148151
) = doBlocking {
149152
val goldenFile = loadGoldenFile(name)
150153
val messages = goldenFile.readLines().filter { it.isNotBlank() }
151154

152-
commonTest(httpStatusCode, backend = backend) {
155+
commonTest(httpStatusCode, backend = backend, requestHandler = requestHandler) {
153156
launch {
154157
for (message in messages) {
155158
channel.writeFully("$message$SSE_SEPARATOR".toByteArray())
@@ -175,8 +178,15 @@ internal fun goldenStreamingFile(
175178
internal fun goldenVertexStreamingFile(
176179
name: String,
177180
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
181+
requestHandler: (HttpRequestData) -> Unit = {},
178182
block: CommonTest,
179-
) = goldenStreamingFile("vertexai/$name", httpStatusCode, block = block)
183+
) =
184+
goldenStreamingFile(
185+
"vertexai/$name",
186+
httpStatusCode,
187+
requestHandler = requestHandler,
188+
block = block
189+
)
180190

181191
/**
182192
* A variant of [goldenStreamingFile] for testing the developer api
@@ -192,8 +202,16 @@ internal fun goldenVertexStreamingFile(
192202
internal fun goldenDevAPIStreamingFile(
193203
name: String,
194204
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
205+
requestHandler: (HttpRequestData) -> Unit = {},
195206
block: CommonTest,
196-
) = goldenStreamingFile("googleai/$name", httpStatusCode, GenerativeBackend.googleAI(), block)
207+
) =
208+
goldenStreamingFile(
209+
"googleai/$name",
210+
httpStatusCode,
211+
GenerativeBackend.googleAI(),
212+
requestHandler,
213+
block
214+
)
197215

198216
/**
199217
* A variant of [commonTest] for performing snapshot tests.

0 commit comments

Comments
 (0)