@@ -18,18 +18,29 @@ package com.google.firebase.ai
1818
1919import com.google.firebase.ai.type.BlockReason
2020import com.google.firebase.ai.type.FinishReason
21+ import com.google.firebase.ai.type.FunctionCallPart
2122import com.google.firebase.ai.type.PromptBlockedException
2223import com.google.firebase.ai.type.ResponseStoppedException
2324import com.google.firebase.ai.type.ServerException
25+ import com.google.firebase.ai.type.content
2426import com.google.firebase.ai.util.goldenDevAPIStreamingFile
2527import io.kotest.assertions.throwables.shouldThrow
2628import io.kotest.matchers.collections.shouldBeEmpty
29+ import io.kotest.matchers.collections.shouldNotBeEmpty
30+ import io.kotest.matchers.nulls.shouldNotBeNull
2731import io.kotest.matchers.shouldBe
32+ import io.kotest.matchers.string.shouldStartWith
33+ import io.ktor.client.engine.mock.toByteArray
34+ import io.ktor.client.request.HttpRequestData
2835import io.ktor.http.HttpStatusCode
2936import kotlin.time.Duration.Companion.seconds
3037import kotlinx.coroutines.flow.collect
3138import kotlinx.coroutines.flow.toList
3239import kotlinx.coroutines.withTimeout
40+ import kotlinx.serialization.json.Json
41+ import kotlinx.serialization.json.jsonArray
42+ import kotlinx.serialization.json.jsonObject
43+ import kotlinx.serialization.json.jsonPrimitive
3344import org.junit.Test
3445import org.junit.runner.RunWith
3546import org.robolectric.RobolectricTestRunner
@@ -85,6 +96,86 @@ internal class DevAPIStreamingSnapshotTests {
8596 }
8697 }
8798
99+ @Test
100+ fun `success call with thought summary and signature` () =
101+ goldenDevAPIStreamingFile(
102+ " streaming-success-thinking-function-call-thought-summary-signature.txt"
103+ ) {
104+ val responses = model.generateContentStream(" prompt" )
105+
106+ withTimeout(testTimeout) {
107+ val responseList = responses.toList()
108+ responseList.isEmpty() shouldBe false
109+ val functionCallResponse = responseList.find { it.functionCalls.isNotEmpty() }
110+ functionCallResponse.shouldNotBeNull()
111+ functionCallResponse.functionCalls.first().let {
112+ it.thoughtSignature.shouldNotBeNull()
113+ it.thoughtSignature.shouldStartWith(" CiIBVKhc7vB" )
114+ }
115+ }
116+ }
117+
118+ @Test
119+ fun `chat call with history including thought summary and signature` () {
120+ var capturedRequest: HttpRequestData ? = null
121+ goldenDevAPIStreamingFile(
122+ " streaming-success-thinking-function-call-thought-summary-signature.txt" ,
123+ requestHandler = { capturedRequest = it }
124+ ) {
125+ val chat = model.startChat()
126+ val firstPrompt = content { text(" first prompt" ) }
127+ val secondPrompt = content { text(" second prompt" ) }
128+ val responses = chat.sendMessageStream(firstPrompt)
129+
130+ withTimeout(testTimeout) {
131+ val responseList = responses.toList()
132+ responseList.shouldNotBeEmpty()
133+
134+ chat.history.let { history ->
135+ history.contains(firstPrompt)
136+ val functionCallPart =
137+ history.flatMap { it.parts }.first { it is FunctionCallPart } as FunctionCallPart
138+ functionCallPart.let {
139+ it.thoughtSignature.shouldNotBeNull()
140+ it.thoughtSignature.shouldStartWith(" CiIBVKhc7vB" )
141+ }
142+ }
143+
144+ // Reset the request so we can be sure we capture the latest version
145+ capturedRequest = null
146+
147+ // We don't care about the response, only the request
148+ val unused = chat.sendMessageStream(secondPrompt).toList()
149+
150+ // Make sure the history contains all prompts seen so far
151+ chat.history.contains(firstPrompt)
152+ chat.history.contains(secondPrompt)
153+
154+ // Put the captured request into a `val` to enable smart casting
155+ val request = capturedRequest
156+ request.shouldNotBeNull()
157+ val bodyAsString = request.body.toByteArray().decodeToString()
158+ bodyAsString.shouldNotBeNull()
159+
160+ val rootElement = Json .parseToJsonElement(bodyAsString).jsonObject
161+
162+ // Traverse the tree: contents -> parts -> thoughtSignature
163+ val contents = rootElement[" contents" ]?.jsonArray
164+
165+ val signature =
166+ contents?.firstNotNullOfOrNull { content ->
167+ content.jsonObject[" parts" ]?.jsonArray?.firstNotNullOfOrNull { part ->
168+ // resulting value is a JsonPrimitive, so we use .content to get the string
169+ part.jsonObject[" thoughtSignature" ]?.jsonPrimitive?.content
170+ }
171+ }
172+
173+ signature.shouldNotBeNull()
174+ signature.shouldStartWith(" CiIBVKhc7vB" )
175+ }
176+ }
177+ }
178+
88179 @Test
89180 fun `prompt blocked for safety` () =
90181 goldenDevAPIStreamingFile(" streaming-failure-prompt-blocked-safety.txt" ) {
0 commit comments