Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,092 changes: 1,092 additions & 0 deletions CLAUDE.md

Large diffs are not rendered by default.

19 changes: 11 additions & 8 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,18 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
processExitCh := make(chan error, 1)
go func() {
defer close(processExitCh)
if err := process.Wait(); err != nil {
if errors.Is(err, termexec.ErrNonZeroExitCode) {
processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err)
} else {
processExitCh <- xerrors.Errorf("failed to wait for process: %w", err)
// Only wait for process if it exists (not in --print-openapi mode)
if process != nil {
if err := process.Wait(); err != nil {
if errors.Is(err, termexec.ErrNonZeroExitCode) {
processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err)
} else {
processExitCh <- xerrors.Errorf("failed to wait for process: %w", err)
}
}
if err := srv.Stop(ctx); err != nil {
logger.Error("Failed to stop server", "error", err)
}
}
if err := srv.Stop(ctx); err != nil {
logger.Error("Failed to stop server", "error", err)
}
}()
if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed {
Expand Down
38 changes: 35 additions & 3 deletions lib/httpapi/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package httpapi

import (
"fmt"
"log/slog"
"strings"
"sync"
"time"
Expand All @@ -12,6 +13,15 @@ import (
"github.com/danielgtaylor/huma/v2"
)

// SubscriberLimitError is returned when the maximum number of SSE subscribers is reached
type SubscriberLimitError struct {
Limit int
}

func (e *SubscriberLimitError) Error() string {
return fmt.Sprintf("subscriber limit reached: %d", e.Limit)
}

type EventType string

const (
Expand Down Expand Up @@ -77,7 +87,9 @@ func convertStatus(status st.ConversationStatus) AgentStatus {
case st.ConversationStatusChanging:
return AgentStatusRunning
default:
panic(fmt.Sprintf("unknown conversation status: %s", status))
// Don't panic - log and return safe default
slog.Error("Unknown conversation status", "status", status)
return AgentStatusRunning
}
}

Expand All @@ -86,6 +98,13 @@ func convertStatus(status st.ConversationStatus) AgentStatus {
// Listeners must actively drain the channel, so it's important to
// set this to a value that is large enough to handle the expected
// number of events.

const (
// MaxSSESubscribers limits the number of concurrent SSE connections
// to prevent resource exhaustion attacks
MaxSSESubscribers = 100
)

func NewEventEmitter(subscriptionBufSize int) *EventEmitter {
return &EventEmitter{
mu: sync.Mutex{},
Expand Down Expand Up @@ -115,6 +134,9 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) {
default:
// If the channel is full, close it.
// Listeners must actively drain the channel.
slog.Warn("Closing slow SSE subscriber - channel buffer full",
"subscriberId", chanId,
"bufferSize", e.subscriptionBufSize)
e.unsubscribeInner(chanId)
}
}
Expand Down Expand Up @@ -199,16 +221,26 @@ func (e *EventEmitter) currentStateAsEvents() []Event {
// - a subscription ID that can be used to unsubscribe.
// - a channel for receiving events.
// - a list of events that allow to recreate the state of the conversation right before the subscription was created.
func (e *EventEmitter) Subscribe() (int, <-chan Event, []Event) {
// - an error if the maximum number of subscribers has been reached.
func (e *EventEmitter) Subscribe() (int, <-chan Event, []Event, error) {
e.mu.Lock()
defer e.mu.Unlock()

// Check subscriber limit to prevent resource exhaustion
if len(e.chans) >= MaxSSESubscribers {
slog.Warn("SSE subscriber limit reached - rejecting new connection",
"limit", MaxSSESubscribers,
"current", len(e.chans))
return 0, nil, nil, &SubscriberLimitError{Limit: MaxSSESubscribers}
}

stateEvents := e.currentStateAsEvents()

// Once a channel becomes full, it will be closed.
ch := make(chan Event, e.subscriptionBufSize)
e.chans[e.chanIdx] = ch
e.chanIdx++
return e.chanIdx - 1, ch, stateEvents
return e.chanIdx - 1, ch, stateEvents, nil
}

// Assumes the caller holds the lock.
Expand Down
22 changes: 19 additions & 3 deletions lib/httpapi/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import (
func TestEventEmitter(t *testing.T) {
t.Run("single-subscription", func(t *testing.T) {
emitter := NewEventEmitter(10)
_, ch, stateEvents := emitter.Subscribe()
_, ch, stateEvents, err := emitter.Subscribe()
assert.NoError(t, err)
assert.Empty(t, ch)
assert.Equal(t, []Event{
{
Expand Down Expand Up @@ -64,7 +65,8 @@ func TestEventEmitter(t *testing.T) {
emitter := NewEventEmitter(10)
channels := make([]<-chan Event, 0, 10)
for i := 0; i < 10; i++ {
_, ch, _ := emitter.Subscribe()
_, ch, _, err := emitter.Subscribe()
assert.NoError(t, err)
channels = append(channels, ch)
}
now := time.Now()
Expand All @@ -83,7 +85,8 @@ func TestEventEmitter(t *testing.T) {

t.Run("close-channel", func(t *testing.T) {
emitter := NewEventEmitter(1)
_, ch, _ := emitter.Subscribe()
_, ch, _, err := emitter.Subscribe()
assert.NoError(t, err)
for i := range 5 {
emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{
{Id: i, Message: fmt.Sprintf("Hello, world! %d", i), Role: st.ConversationRoleUser, Time: time.Now()},
Expand All @@ -98,4 +101,17 @@ func TestEventEmitter(t *testing.T) {
t.Fatalf("read should not block")
}
})

t.Run("subscriber-limit", func(t *testing.T) {
emitter := NewEventEmitter(10)
// Subscribe up to the limit
for i := 0; i < MaxSSESubscribers; i++ {
_, _, _, err := emitter.Subscribe()
assert.NoError(t, err, "subscription %d should succeed", i)
}
// Next subscription should fail
_, _, _, err := emitter.Subscribe()
assert.Error(t, err)
assert.IsType(t, &SubscriberLimitError{}, err)
})
}
68 changes: 60 additions & 8 deletions lib/httpapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ func (s *Server) GetOpenAPI() string {
// because the action of taking a snapshot takes time too.
const snapshotInterval = 25 * time.Millisecond

const (
// MaxMessageSize is the maximum size for user messages (10MB)
MaxMessageSize = 10 * 1024 * 1024
// MaxRawMessageSize is the maximum size for raw terminal input (1KB)
// Raw messages go directly to the terminal, so we're more conservative
MaxRawMessageSize = 1024
)

type ServerConfig struct {
AgentType mf.AgentType
Process *termexec.Process
Expand All @@ -107,11 +115,13 @@ type ServerConfig struct {
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
// Viper/Cobra use different separators (space for env vars, comma for flags),
// so these characters likely indicate user error.
func parseAllowedHosts(input []string) ([]string, error) {
func parseAllowedHosts(input []string, logger *slog.Logger) ([]string, error) {
if len(input) == 0 {
return nil, fmt.Errorf("the list must not be empty")
}
if slices.Contains(input, "*") {
logger.Warn("⚠️ SECURITY WARNING: Host wildcard '*' allows requests from ANY host",
"recommendation", "Only use '*' in development. In production, specify exact hosts.")
return []string{"*"}, nil
}
// First pass: whitespace & comma checks (surface these errors first)
Expand Down Expand Up @@ -157,11 +167,13 @@ func parseAllowedHosts(input []string) ([]string, error) {
}

// Validate allowed origins
func parseAllowedOrigins(input []string) ([]string, error) {
func parseAllowedOrigins(input []string, logger *slog.Logger) ([]string, error) {
if len(input) == 0 {
return nil, fmt.Errorf("the list must not be empty")
}
if slices.Contains(input, "*") {
logger.Warn("⚠️ SECURITY WARNING: CORS wildcard '*' allows requests from ANY website",
"recommendation", "Only use '*' in development. In production, specify exact origins.")
return []string{"*"}, nil
}
// Viper/Cobra use different separators (space for env vars, comma for flags),
Expand Down Expand Up @@ -194,11 +206,11 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {

logger := logctx.From(ctx)

allowedHosts, err := parseAllowedHosts(config.AllowedHosts)
allowedHosts, err := parseAllowedHosts(config.AllowedHosts, logger)
if err != nil {
return nil, xerrors.Errorf("failed to parse allowed hosts: %w", err)
}
allowedOrigins, err := parseAllowedOrigins(config.AllowedOrigins)
allowedOrigins, err := parseAllowedOrigins(config.AllowedOrigins, logger)
if err != nil {
return nil, xerrors.Errorf("failed to parse allowed origins: %w", err)
}
Expand Down Expand Up @@ -321,7 +333,19 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) {
func (s *Server) StartSnapshotLoop(ctx context.Context) {
s.conversation.StartSnapshotLoop(ctx)
go func() {
ticker := time.NewTicker(snapshotInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
s.logger.Info("Snapshot loop exiting")
return
case <-ticker.C:
s.emitter.UpdateStatusAndEmitChanges(s.conversation.Status())
s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages())
s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen())
}
currentStatus := s.conversation.Status()

// Send initial prompt when agent becomes stable for the first time
Expand Down Expand Up @@ -431,6 +455,18 @@ func (s *Server) getMessages(ctx context.Context, input *struct{}) (*MessagesRes

// createMessage handles POST /message
func (s *Server) createMessage(ctx context.Context, input *MessageRequest) (*MessageResponse, error) {
// Validate message size based on type
maxSize := MaxMessageSize
if input.Body.Type == MessageTypeRaw {
maxSize = MaxRawMessageSize
}

if len(input.Body.Content) > maxSize {
return nil, huma.Error400BadRequest(
fmt.Sprintf("message too large (max %d bytes)", maxSize),
)
}

s.mu.Lock()
defer s.mu.Unlock()

Expand Down Expand Up @@ -497,7 +533,13 @@ func (s *Server) uploadFiles(ctx context.Context, input *struct {

// subscribeEvents is an SSE endpoint that sends events to the client
func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.Sender) {
subscriberId, ch, stateEvents := s.emitter.Subscribe()
subscriberId, ch, stateEvents, err := s.emitter.Subscribe()
if err != nil {
s.logger.Error("Failed to subscribe", "error", err)
// Send error to client and close connection
_ = send.Data(map[string]string{"error": err.Error()})
return
}
defer s.emitter.Unsubscribe(subscriberId)
s.logger.Info("New subscriber", "subscriberId", subscriberId)
for _, event := range stateEvents {
Expand Down Expand Up @@ -532,7 +574,13 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.
}

func (s *Server) subscribeScreen(ctx context.Context, input *struct{}, send sse.Sender) {
subscriberId, ch, stateEvents := s.emitter.Subscribe()
subscriberId, ch, stateEvents, err := s.emitter.Subscribe()
if err != nil {
s.logger.Error("Failed to subscribe to screen", "error", err)
// Send error to client and close connection
_ = send.Data(map[string]string{"error": err.Error()})
return
}
defer s.emitter.Unsubscribe(subscriberId)
s.logger.Info("New screen subscriber", "subscriberId", subscriberId)
for _, event := range stateEvents {
Expand Down Expand Up @@ -569,8 +617,12 @@ func (s *Server) subscribeScreen(ctx context.Context, input *struct{}, send sse.
func (s *Server) Start() error {
addr := fmt.Sprintf(":%d", s.port)
s.srv = &http.Server{
Addr: addr,
Handler: s.router,
Addr: addr,
Handler: s.router,
ReadTimeout: 15 * time.Second, // Prevent slow header attacks
WriteTimeout: 0, // Disabled for SSE long-polling
IdleTimeout: 60 * time.Second, // Close idle connections
ReadHeaderTimeout: 5 * time.Second, // Specifically for headers
}

return s.srv.ListenAndServe()
Expand Down
5 changes: 3 additions & 2 deletions lib/logctx/logctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ func WithLogger(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, loggerKey, logger)
}

// From retrieves the logger from the context or panics if no logger is found
// From retrieves the logger from the context or returns the default logger if none is found
func From(ctx context.Context) *slog.Logger {
if logger, ok := ctx.Value(loggerKey).(*slog.Logger); ok {
return logger
}
panic("no logger found in context")
// Return default logger instead of panicking
return slog.Default()
}

// plucked from log/slog
Expand Down
Loading