package app import ( "bytes" "context" "embed" "encoding/base64" "encoding/binary" "encoding/json" "fmt" "log" "log/slog" "net/http" "os" "os/signal" "strconv" "strings" "syscall" "text/template" "time" "github.com/gorilla/csrf" "github.com/gorilla/handlers" "github.com/gorilla/mux" ) //go:embed templates var tplFolder embed.FS // Config holds server configuration type Config struct { Host string Port string CSRFKey []byte TrustedOrigins []string MaxWidth int MaxHeight int SSEFlushInterval time.Duration WriteTimeout time.Duration ReadTimeout time.Duration ShutdownTimeout time.Duration DB string } // DefaultConfig returns default server configuration func DefaultConfig() *Config { return &Config{ Host: "localhost", Port: "5002", CSRFKey: []byte("0e6139e71f1972259e4b2ce6464b80a3"), // Should be from env in production TrustedOrigins: []string{"localhost:5002"}, MaxWidth: 1024, MaxHeight: 1024, SSEFlushInterval: 200 * time.Millisecond, WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, ShutdownTimeout: 10 * time.Second, DB: "points.db?_journal_mode=WAL&cache=shared&_foreign_keys=0&synchronous=1", } } type CellState struct { X int `json:"x"` Y int `json:"y"` State bool `json:"state"` } // SSEConnection represents an SSE connection type SSEConnection struct { ID string Channel chan CellState Area Area LastSent time.Time } // Server encapsulates the HTTP server and its dependencies type Server struct { config *Config logger *slog.Logger templates *template.Template data *SpatialHashMap sseConnections map[string]*SSEConnection router *mux.Router httpServer *http.Server } // NewServer creates a new server instance func NewServer(config *Config, spatialMap *SpatialHashMap) (*Server, error) { if config == nil { config = DefaultConfig() } logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelInfo, })) templates, err := template.ParseFS(tplFolder, "templates/*.html") if err != nil { log.Fatal(err) } server := &Server{ config: config, logger: logger, templates: templates, data: spatialMap, sseConnections: make(map[string]*SSEConnection), } server.setupRoutes() return server, nil } // setupRoutes configures the HTTP routes and middleware func (s *Server) setupRoutes() { s.router = mux.NewRouter() // Middleware chain s.router.Use(s.loggingMiddleware) s.router.Use(s.corsMiddleware) s.router.Use(s.csrfMiddleware) // Routes s.router.HandleFunc("/api/set-pos", s.setPositionHandler).Methods("POST") s.router.HandleFunc("/sse/{params}", s.streamEventsHandler).Methods("GET") s.router.HandleFunc("/", s.serveIndexHandler).Methods("GET") } // Middleware functions func (s *Server) loggingMiddleware(next http.Handler) http.Handler { return handlers.LoggingHandler(os.Stdout, next) } func (s *Server) corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } func (s *Server) csrfMiddleware(next http.Handler) http.Handler { csrfProtection := csrf.Protect( s.config.CSRFKey, csrf.Secure(false), // Set to true in production with HTTPS csrf.HttpOnly(true), csrf.TrustedOrigins(s.config.TrustedOrigins), ) return csrfProtection(next) } // HTTP Handlers func (s *Server) setPositionHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") var cellState CellState if err := json.NewDecoder(r.Body).Decode(&cellState); err != nil { s.logger.Error("Failed to decode JSON", "error", err) s.writeErrorResponse(w, "Invalid JSON", http.StatusBadRequest) return } // Update spatial data if cellState.State { s.data.Insert(cellState.X, cellState.Y) } else { s.data.Remove(cellState.X, cellState.Y) } // Notify SSE clients s.notifySSEClients(cellState) response := map[string]string{"status": "ok"} if err := json.NewEncoder(w).Encode(response); err != nil { s.logger.Error("Failed to encode response", "error", err) } } func (s *Server) streamEventsHandler(w http.ResponseWriter, r *http.Request) { params, err := s.parseSSEParams(r) if err != nil { s.logger.Error("Invalid SSE parameters", "error", err) s.writeErrorResponse(w, "Invalid parameters", http.StatusBadRequest) return } area := Area{ X: params[0], Y: params[1], Width: params[2], Height: params[3], } if err := s.validateArea(area); err != nil { s.logger.Error("Invalid SSE area", "error", err) s.writeErrorResponse(w, err.Error(), http.StatusBadRequest) return } // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") flusher, ok := w.(http.Flusher) if !ok { s.writeErrorResponse(w, "Streaming unsupported", http.StatusInternalServerError) return } client := s.createSSEClient(area) defer s.removeSSEClient(client.ID) s.handleSSEConnection(w, r, client, flusher) } func (s *Server) serveIndexHandler(w http.ResponseWriter, r *http.Request) { data := map[string]any{ csrf.TemplateTag: csrf.TemplateField(r), } if err := s.templates.ExecuteTemplate(w, "index.html", data); err != nil { s.logger.Error("Failed to execute template", "error", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } func (s *Server) parseSSEParams(r *http.Request) ([]int, error) { vars := mux.Vars(r) path := vars["params"] paramStrs := strings.Split(path, ",") if len(paramStrs) != 4 { return nil, fmt.Errorf("expected 4 parameters, got %d", len(paramStrs)) } params := make([]int, 4) for i, paramStr := range paramStrs { var err error params[i], err = strconv.Atoi(paramStr) if err != nil { return nil, fmt.Errorf("invalid parameter %d: %s", i, paramStr) } } return params, nil } func (s *Server) validateArea(area Area) error { if area.Width <= 0 || area.Height <= 0 { return fmt.Errorf("width and height must be positive") } if area.Width > s.config.MaxWidth || area.Height > s.config.MaxHeight { return fmt.Errorf("area dimensions exceed maximum bounds") } return nil } func (s *Server) createSSEClient(area Area) *SSEConnection { clientID := fmt.Sprintf("%d,%d,%d,%d", area.X, area.Y, area.Width, area.Height) client := &SSEConnection{ ID: clientID, Channel: make(chan CellState, 10), Area: area, LastSent: time.Unix(0, 0), } s.sseConnections[clientID] = client return client } func (s *Server) removeSSEClient(clientID string) { if client, exists := s.sseConnections[clientID]; exists { close(client.Channel) delete(s.sseConnections, clientID) s.logger.Info("SSE client disconnected", "client", clientID) } } func (s *Server) notifySSEClients(cellState CellState) { for _, client := range s.sseConnections { if client.Area.Contains(cellState.X, cellState.Y) { select { case client.Channel <- cellState: default: // Channel is full, skip this update s.logger.Warn("SSE client channel full, skipping update", "client", client.ID) } } } } func (s *Server) handleSSEConnection(w http.ResponseWriter, r *http.Request, client *SSEConnection, flusher http.Flusher) { clientGone := r.Context().Done() if err := s.sendPointsState(w, client, flusher); err != nil { s.logger.Error("Failed to send points state", "error", err, "client", client.ID) return } for { select { case <-clientGone: s.logger.Info("SSE connection closed", "client", client.ID) return case cellState := <-client.Channel: // Throttle updates if since := time.Since(client.LastSent); since < s.config.SSEFlushInterval { select { case client.Channel <- cellState: default: // Channel is full, skip this update s.logger.Warn("SSE client channel full, skipping update", "client", client.ID) } time.Sleep(s.config.SSEFlushInterval - since) continue } buffer := new(bytes.Buffer) binary.Write(buffer, binary.LittleEndian, int32(cellState.X)) binary.Write(buffer, binary.LittleEndian, int32(cellState.Y)) binary.Write(buffer, binary.LittleEndian, boolToInt(cellState.State)) datab64 := base64.StdEncoding.EncodeToString(buffer.Bytes()) if _, err := fmt.Fprintf(w, "event: update\ndata: %s\n\n", datab64); err != nil { } flusher.Flush() client.LastSent = time.Now() } } } func (s *Server) sendPointsState(w http.ResponseWriter, client *SSEConnection, flusher http.Flusher) error { area := client.Area points := s.data.GetLocalPoints(area) buffer := new(bytes.Buffer) for _, point := range points { binary.Write(buffer, binary.LittleEndian, int32(point.X)) binary.Write(buffer, binary.LittleEndian, int32(point.Y)) } datab64 := base64.StdEncoding.EncodeToString(buffer.Bytes()) if _, err := fmt.Fprintf(w, "event: state\ndata: %s\n\n", datab64); err != nil { return err } flusher.Flush() return nil } func (s *Server) writeErrorResponse(w http.ResponseWriter, message string, statusCode int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) response := map[string]string{"error": message} json.NewEncoder(w).Encode(response) } // Server lifecycle methods func (s *Server) Start() error { addr := fmt.Sprintf("%s:%s", s.config.Host, s.config.Port) s.httpServer = &http.Server{ Handler: s.router, Addr: addr, WriteTimeout: s.config.WriteTimeout, ReadTimeout: s.config.ReadTimeout, } s.logger.Info("Server starting", "address", addr) if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { return fmt.Errorf("server failed to start: %w", err) } return nil } func (s *Server) Stop(ctx context.Context) error { s.logger.Info("Server shutting down...") // Close all SSE connections for _, client := range s.sseConnections { close(client.Channel) } s.sseConnections = make(map[string]*SSEConnection) if s.httpServer != nil { err := s.httpServer.Shutdown(ctx) if err != nil { s.logger.Error("Failed to shutdown server", "error", err) return fmt.Errorf("server shutdown error: %w", err) } } if s.data != nil { s.logger.Info("Closing spatial hash map") err := s.data.Close() if err != nil { s.logger.Error("Failed to close spatial hash map", "error", err) return fmt.Errorf("failed to close spatial hash map: %w", err) } } s.logger.Info("Server stopped gracefully") return nil } // Run starts the server with graceful shutdown handling func Run() error { return RunWithConfig(nil, nil) } // RunWithConfig starts the server with custom configuration and spatial map func RunWithConfig(config *Config, spatialMap *SpatialHashMap) (err error) { if config == nil { config = DefaultConfig() } if spatialMap == nil { spatialMap, err = NewSpatialHashMap(config.DB) // Assumes this function exists if err != nil { return fmt.Errorf("failed to create spatial hash map: %w", err) } err = spatialMap.InitSchema() if err != nil { return fmt.Errorf("failed to create spatial hash map: %w", err) } } defer spatialMap.Close() server, err := NewServer(config, spatialMap) if err != nil { return fmt.Errorf("failed to create server: %w", err) } // Setup graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) // Start server in goroutine errChan := make(chan error, 1) go func() { errChan <- server.Start() }() // Wait for shutdown signal or server error select { case err := <-errChan: if err != nil { return err } case <-sigChan: server.logger.Info("Shutdown signal received") ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout) defer cancel() if err := server.Stop(ctx); err != nil { server.logger.Error("Server shutdown error", "error", err) return err } server.logger.Info("Server stopped gracefully") } return nil }