Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions cmd/agentcli/prestage.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func runPreStage(cfg cliConfig, messages []oai.Message, stderr io.Writer) ([]oai
}
}
prepMessages = append(prepMessages, applyTranscriptHygiene(normalizedIn, cfg.debug)...)
req := oai.ChatCompletionsRequest{
req := oai.ChatCompletionsRequest{
Model: prepModel,
Messages: prepMessages,
}
Expand All @@ -168,7 +168,14 @@ func runPreStage(cfg cliConfig, messages []oai.Message, stderr io.Writer) ([]oai
} else if effectiveTemp != nil {
req.Temperature = effectiveTemp
}
// Create a dedicated client honoring pre-stage timeout and normal retry policy
// Enforce prompt to fit context window for pre-stage as well
window := oai.ContextWindowForModel(prepModel)
promptBudget := oai.PromptTokenBudget(window, 0)
if oai.EstimateTokens(req.Messages) > promptBudget {
req.Messages = oai.TrimMessagesToFit(req.Messages, promptBudget)
}

// Create a dedicated client honoring pre-stage timeout and normal retry policy
httpClient := oai.NewClientWithRetry(prepBaseURL, prepAPIKey, cfg.prepHTTPTimeout, oai.RetryPolicy{MaxRetries: retries, Backoff: backoff})
dumpJSONIfDebug(stderr, "prep.request", req, cfg.debug)
// Tag context with audit stage so HTTP audit lines include stage: "prep"
Expand Down
14 changes: 11 additions & 3 deletions cmd/agentcli/run_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func runAgent(cfg cliConfig, stdout io.Writer, stderr io.Writer) int {
for {
// Apply transcript hygiene before sending to the API when -debug is off
hygienic := applyTranscriptHygiene(messages, cfg.debug)
req := oai.ChatCompletionsRequest{
req := oai.ChatCompletionsRequest{
Model: cfg.model,
Messages: hygienic,
}
Expand All @@ -203,11 +203,19 @@ func runAgent(cfg cliConfig, stdout io.Writer, stderr io.Writer) int {
req.ToolChoice = "auto"
}

// Include MaxTokens only when a positive completionCap is set.
// Include MaxTokens only when a positive completionCap is set.
if completionCap > 0 {
req.MaxTokens = completionCap
}

// Enforce prompt to fit the context window, leaving room for completionCap
// Compute prompt budget and trim deterministically when needed
window := oai.ContextWindowForModel(cfg.model)
promptBudget := oai.PromptTokenBudget(window, req.MaxTokens)
if oai.EstimateTokens(req.Messages) > promptBudget {
req.Messages = oai.TrimMessagesToFit(req.Messages, promptBudget)
}

// Pre-flight validate message sequence to avoid API 400s for stray tool messages
if err := oai.ValidateMessageSequence(req.Messages); err != nil {
safeFprintf(stderr, "error: %v\n", err)
Expand Down Expand Up @@ -276,7 +284,7 @@ func runAgent(cfg cliConfig, stdout io.Writer, stderr io.Writer) int {
callCtx, cancel = context.WithTimeout(context.Background(), cfg.httpTimeout)
}

// Fallback: non-streaming request
// Fallback: non-streaming request
resp, err := httpClient.CreateChatCompletion(callCtx, req)
cancel()
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions internal/oai/context_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,14 @@ func ClampCompletionCap(messages []Message, requestedCap int, window int) int {
}
return requestedCap
}

// PromptTokenBudget returns a safe token budget for the prompt given a
// model context window and a desired completion cap. A small safety margin
// of 32 tokens is reserved for reply/control tokens.
func PromptTokenBudget(window int, completionCap int) int {
budget := window - completionCap - 32
if budget < 1 {
return 1
}
return budget
}
182 changes: 182 additions & 0 deletions internal/oai/trim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package oai

// TrimMessagesToFit reduces a transcript so its estimated tokens do not exceed
// the provided limit. Policy:
// - Pin the first system and developer messages when present.
// - Drop oldest non-pinned messages first until within limit.
// - If only pinned remain and still exceed limit, truncate their content
// proportionally but keep both messages.
// - As a last resort, keep only the newest message, truncated to fit.
func TrimMessagesToFit(in []Message, limit int) []Message {
if limit <= 0 || len(in) == 0 {
return []Message{}
}
estimate := func(msgs []Message) int { return EstimateTokens(msgs) }

// Fast path: already fits
if estimate(in) <= limit {
return in
}

out := append([]Message(nil), in...)

// Drop oldest non-pinned messages until within limit.
for len(out) > 0 && estimate(out) > limit {
// Find first indices of pinned roles in current slice
sysIdx, devIdx := -1, -1
for i := range out {
if sysIdx == -1 && out[i].Role == RoleSystem {
sysIdx = i
}
if devIdx == -1 && out[i].Role == RoleDeveloper {
devIdx = i
}
if sysIdx != -1 && devIdx != -1 {
break
}
}
// Remove first non-pinned from the front if any
removed := false
for j := 0; j < len(out); j++ {
if j != sysIdx && j != devIdx {
out = append(out[:j], out[j+1:]...)
removed = true
break
}
}
if !removed {
// Only pinned remain; proceed to truncation
break
}
}

if estimate(out) <= limit {
return out
}

// Truncation path: only pinned remain or still too large
// Identify pinned indices in current slice
sysIdx, devIdx := -1, -1
for i := range out {
if sysIdx == -1 && out[i].Role == RoleSystem {
sysIdx = i
}
if devIdx == -1 && out[i].Role == RoleDeveloper {
devIdx = i
}
}

// If no pinned present, keep newest single message truncated to fit
if sysIdx == -1 && devIdx == -1 {
last := out[len(out)-1]
return []Message{truncateMessageToBudget(last, limit)}
}

cur := estimate(out)
if cur <= limit {
return out
}

// Compute budgets
if sysIdx != -1 && devIdx != -1 {
sysTok := EstimateTokens([]Message{out[sysIdx]})
devTok := EstimateTokens([]Message{out[devIdx]})
totalPinned := sysTok + devTok
if totalPinned == 0 {
totalPinned = 1
}
nonPinned := cur - totalPinned
targetPinned := limit - nonPinned
if targetPinned < 2 { // ensure at least 1 per pinned
targetPinned = 2
}
// Allocate at least 1 token to each, distribute remainder proportionally
minPerPinned := 1
remaining := targetPinned - 2*minPerPinned
if remaining < 0 {
remaining = 0
}
var extraSys, extraDev int
if sysTok+devTok > 0 && remaining > 0 {
extraSys = (sysTok * remaining) / (sysTok + devTok)
extraDev = remaining - extraSys
} else {
extraSys, extraDev = 0, 0
}
targetSys := minPerPinned + extraSys
targetDev := minPerPinned + extraDev
out[sysIdx] = truncateMessageToBudget(out[sysIdx], targetSys)
out[devIdx] = truncateMessageToBudget(out[devIdx], targetDev)
} else if sysIdx != -1 { // only system pinned
// allocate entire limit minus non-system tokens
nonSys := cur - EstimateTokens([]Message{out[sysIdx]})
budget := limit - nonSys
if budget < 1 {
budget = 1
}
out[sysIdx] = truncateMessageToBudget(out[sysIdx], budget)
} else if devIdx != -1 { // only developer pinned
nonDev := cur - EstimateTokens([]Message{out[devIdx]})
budget := limit - nonDev
if budget < 1 {
budget = 1
}
out[devIdx] = truncateMessageToBudget(out[devIdx], budget)
}

// Final guard: if still above limit, drop oldest non-pinned if any; otherwise truncate newest to fit
for estimate(out) > limit {
removed := false
// Try to remove a non-pinned from the front
// Recompute pinned indices
sysIdx, devIdx = -1, -1
for i := range out {
if sysIdx == -1 && out[i].Role == RoleSystem {
sysIdx = i
}
if devIdx == -1 && out[i].Role == RoleDeveloper {
devIdx = i
}
}
for j := 0; j < len(out); j++ {
if j != sysIdx && j != devIdx {
out = append(out[:j], out[j+1:]...)
removed = true
break
}
}
if !removed {
// No non-pinned remain; keep newest one truncated
last := out[len(out)-1]
out = []Message{truncateMessageToBudget(last, limit)}
break
}
}

return out
}

// truncateMessageToBudget returns a copy of msg with content truncated such that
// the single-message token estimate is <= budget (best-effort heuristic).
func truncateMessageToBudget(msg Message, budget int) Message {
if budget <= 1 {
msg.Content = ""
return msg
}
// Binary search on content length, using EstimateTokens heuristic
lo, hi := 0, len(msg.Content)
best := 0
for lo <= hi {
mid := (lo + hi) / 2
test := msg
test.Content = truncate(msg.Content, mid)
if EstimateTokens([]Message{test}) <= budget {
best = mid
lo = mid + 1
} else {
hi = mid - 1
}
}
msg.Content = truncate(msg.Content, best)
return msg
}
85 changes: 85 additions & 0 deletions internal/oai/trim_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package oai

import "testing"

// helper to build a message with role and content
func m(role, content string) Message { return Message{Role: role, Content: content} }

func TestTrimMessagesToFit_PreservesSystemAndDeveloper(t *testing.T) {
sys := m(RoleSystem, repeat("S", 4000)) // ~1000 tokens
dev := m(RoleDeveloper, repeat("D", 4000)) // ~1000 tokens
u1 := m(RoleUser, repeat("u", 4000)) // ~1000 tokens
a1 := m(RoleAssistant, repeat("a", 4000)) // ~1000 tokens
u2 := m(RoleUser, repeat("u", 4000)) // ~1000 tokens
in := []Message{sys, dev, u1, a1, u2}

// Limit so that we cannot keep all messages; must drop from the front (u1,a1)
limit := EstimateTokens(in) - 1500
out := TrimMessagesToFit(in, limit)

if EstimateTokens(out) > limit {
t.Fatalf("trim did not reduce to limit: got=%d limit=%d", EstimateTokens(out), limit)
}
if len(out) >= 2 {
if out[0].Role != RoleSystem {
t.Fatalf("first message should be system; got %q", out[0].Role)
}
if out[1].Role != RoleDeveloper {
t.Fatalf("second message should be developer; got %q", out[1].Role)
}
} else {
t.Fatalf("expected to preserve at least system and developer; got %d", len(out))
}
}

func TestTrimMessagesToFit_DropsOldestNonPinned(t *testing.T) {
sys := m(RoleSystem, "policy")
// 5 alternating user/assistant messages
msgs := []Message{sys}
for i := 0; i < 5; i++ {
msgs = append(msgs, m(RoleUser, repeat("U", 2000)))
msgs = append(msgs, m(RoleAssistant, repeat("A", 2000)))
}
// Force heavy trim
limit := EstimateTokens(msgs) / 2
out := TrimMessagesToFit(msgs, limit)
if EstimateTokens(out) > limit {
t.Fatalf("expected tokens <= limit; got=%d limit=%d", EstimateTokens(out), limit)
}
// Ensure the newest non-pinned message remains (the last assistant)
if out[len(out)-1].Role != RoleAssistant {
t.Fatalf("expected newest assistant at tail; got %q", out[len(out)-1].Role)
}
}

func TestTrimMessagesToFit_OnlySystemDeveloperTooLarge_TruncatesContent(t *testing.T) {
sys := m(RoleSystem, repeat("S", 20000)) // ~5000 tokens
dev := m(RoleDeveloper, repeat("D", 20000)) // ~5000 tokens
in := []Message{sys, dev}
limit := 3000 // far below combined estimate
out := TrimMessagesToFit(in, limit)
if EstimateTokens(out) > limit {
t.Fatalf("expected tokens <= limit after truncation; got=%d limit=%d", EstimateTokens(out), limit)
}
if len(out) != 2 {
t.Fatalf("should keep both system and developer; got %d", len(out))
}
if len(out[0].Content) >= len(sys.Content) {
t.Fatalf("system content was not truncated")
}
if len(out[1].Content) >= len(dev.Content) {
t.Fatalf("developer content was not truncated")
}
}

// repeat returns a string consisting of count repetitions of s.
func repeat(s string, count int) string {
if count <= 0 {
return ""
}
b := make([]byte, 0, len(s)*count)
for i := 0; i < count; i++ {
b = append(b, s...)
}
return string(b)
}