Artem Sukhodolskyi 820879b79f initial commit
2025-10-21 12:55:17 +02:00

482 lines
12 KiB
Go

package app
import (
"bytes"
"context"
"embed"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"log"
"log/slog"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"text/template"
"time"
"github.com/gorilla/csrf"
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
)
// tplFolder holds the contents of the templates directory, embedded into the
// binary by the go:embed directive below. NewServer parses its
// "templates/*.html" files.
//
//go:embed templates
var tplFolder embed.FS
// Config holds server configuration.
//
// A nil *Config passed to NewServer or RunWithConfig is replaced by
// DefaultConfig().
type Config struct {
// Host and Port are joined as "Host:Port" to form the listen address.
Host string
Port string
// CSRFKey is the key handed to csrf.Protect.
CSRFKey []byte
// TrustedOrigins is handed to csrf.TrustedOrigins.
TrustedOrigins []string
// MaxWidth and MaxHeight are the largest area dimensions an SSE client may
// subscribe to; larger requests are rejected by validateArea.
MaxWidth int
MaxHeight int
// SSEFlushInterval is the minimum time between two update events written
// to the same SSE client.
SSEFlushInterval time.Duration
// WriteTimeout and ReadTimeout are copied onto the http.Server.
WriteTimeout time.Duration
ReadTimeout time.Duration
// ShutdownTimeout bounds the context passed to Server.Stop when a shutdown
// signal is received.
ShutdownTimeout time.Duration
// DB is the string passed to NewSpatialHashMap when no spatial map is
// supplied to RunWithConfig.
DB string
}
// DefaultConfig returns a newly allocated Config filled with the built-in
// defaults. Every call returns an independent value, so callers may modify
// the result freely.
func DefaultConfig() *Config {
	cfg := new(Config)

	// Listen address.
	cfg.Host = "localhost"
	cfg.Port = "5002"

	// CSRF protection.
	// NOTE(review): the key is compiled into the binary; it should be from env
	// in production.
	cfg.CSRFKey = []byte("0e6139e71f1972259e4b2ce6464b80a3")
	cfg.TrustedOrigins = []string{"localhost:5002"}

	// Limits for SSE subscriptions.
	cfg.MaxWidth = 1024
	cfg.MaxHeight = 1024
	cfg.SSEFlushInterval = 200 * time.Millisecond

	// HTTP server timeouts.
	cfg.WriteTimeout = 15 * time.Second
	cfg.ReadTimeout = 15 * time.Second
	cfg.ShutdownTimeout = 10 * time.Second

	// Storage location handed to NewSpatialHashMap.
	cfg.DB = "points.db?_journal_mode=WAL&cache=shared&_foreign_keys=0&synchronous=1"

	return cfg
}
// CellState is the state of a single cell at coordinates (X, Y).
//
// It is the JSON body accepted by setPositionHandler and the value delivered
// to SSE clients. A true State makes the handler insert the point into the
// spatial map; a false State makes it remove the point.
type CellState struct {
X int `json:"x"`
Y int `json:"y"`
State bool `json:"state"`
}
// SSEConnection represents an SSE connection.
type SSEConnection struct {
// ID is the key under which the connection is stored in
// Server.sseConnections.
ID string
// Channel carries the cell updates queued for this client.
Channel chan CellState
// Area is the region the client subscribed to; only updates for which
// Area.Contains reports true are queued.
Area Area
// LastSent is the time the most recent update event was written to the
// client; it is used to throttle updates.
LastSent time.Time
}
// Server encapsulates the HTTP server and its dependencies.
//
// NOTE(review): sseConnections is read and written by several methods with no
// visible locking; confirm whether handlers can touch it concurrently.
type Server struct {
// config is the configuration the server was created with.
config *Config
// logger is the structured logger used by the handlers and lifecycle
// methods.
logger *slog.Logger
// templates holds the templates parsed from the embedded templates
// directory.
templates *template.Template
// data stores the points that are inserted, removed and queried by the
// handlers.
data *SpatialHashMap
// sseConnections maps SSEConnection.ID to the registered SSE clients.
sseConnections map[string]*SSEConnection
// router is built by setupRoutes and served by httpServer.
router *mux.Router
// httpServer is nil until Start has been called.
httpServer *http.Server
}
// NewServer builds a Server that serves the given spatial map.
//
// A nil config is replaced by DefaultConfig(). The server logs JSON records
// at Info level to standard output, parses every "templates/*.html" file
// from the embedded template folder and registers its routes before it is
// returned.
//
// NOTE(review): a template parse failure terminates the process through
// log.Fatal instead of being reported through the error result, so the
// returned error is currently always nil.
func NewServer(config *Config, spatialMap *SpatialHashMap) (*Server, error) {
	if config == nil {
		config = DefaultConfig()
	}

	handlerOpts := &slog.HandlerOptions{Level: slog.LevelInfo}
	jsonLogger := slog.New(slog.NewJSONHandler(os.Stdout, handlerOpts))

	parsed, parseErr := template.ParseFS(tplFolder, "templates/*.html")
	if parseErr != nil {
		log.Fatal(parseErr)
	}

	srv := &Server{
		config:         config,
		logger:         jsonLogger,
		templates:      parsed,
		data:           spatialMap,
		sseConnections: map[string]*SSEConnection{},
	}
	srv.setupRoutes()

	return srv, nil
}
// setupRoutes creates the router, installs the middleware chain and
// registers the application's endpoints.
func (s *Server) setupRoutes() {
	router := mux.NewRouter()

	// Middleware is registered in this order: logging, CORS, CSRF.
	router.Use(
		s.loggingMiddleware,
		s.corsMiddleware,
		s.csrfMiddleware,
	)

	// Endpoints.
	router.HandleFunc("/api/set-pos", s.setPositionHandler).Methods("POST")
	router.HandleFunc("/sse/{params}", s.streamEventsHandler).Methods("GET")
	router.HandleFunc("/", s.serveIndexHandler).Methods("GET")

	s.router = router
}
// Middleware functions

// loggingMiddleware wraps next in the gorilla logging handler, which writes
// its request log to standard output.
func (s *Server) loggingMiddleware(next http.Handler) http.Handler {
	accessLog := os.Stdout
	return handlers.LoggingHandler(accessLog, next)
}
// corsMiddleware adds CORS headers to every response. OPTIONS requests are
// answered directly with 200 OK and are not passed on to next.
//
// NOTE(review): the allowed origin is the wildcard "*"; confirm that this is
// intended alongside the CSRF protection.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	withCORS := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		hdr.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		if r.Method != "OPTIONS" {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
	}
	return http.HandlerFunc(withCORS)
}
// csrfMiddleware wraps next with gorilla/csrf protection, configured with
// the key and trusted origins taken from the server configuration.
func (s *Server) csrfMiddleware(next http.Handler) http.Handler {
	protect := csrf.Protect(
		s.config.CSRFKey,
		csrf.Secure(false), // Set to true in production with HTTPS
		csrf.HttpOnly(true),
		csrf.TrustedOrigins(s.config.TrustedOrigins),
	)
	protected := protect(next)
	return protected
}
// HTTP Handlers

// setPositionHandler handles POST /api/set-pos.
//
// The request body is decoded as a JSON CellState. A true State inserts the
// point into the spatial map and a false State removes it. The change is
// then offered to the registered SSE clients and {"status":"ok"} is written
// back. A body that cannot be decoded yields a 400 response.
func (s *Server) setPositionHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	var update CellState
	decodeErr := json.NewDecoder(r.Body).Decode(&update)
	if decodeErr != nil {
		s.logger.Error("Failed to decode JSON", "error", decodeErr)
		s.writeErrorResponse(w, "Invalid JSON", http.StatusBadRequest)
		return
	}

	// Apply the change to the spatial data.
	switch {
	case update.State:
		s.data.Insert(update.X, update.Y)
	default:
		s.data.Remove(update.X, update.Y)
	}

	// Let the subscribed SSE clients know.
	s.notifySSEClients(update)

	reply := map[string]string{"status": "ok"}
	encodeErr := json.NewEncoder(w).Encode(reply)
	if encodeErr != nil {
		s.logger.Error("Failed to encode response", "error", encodeErr)
	}
}
// streamEventsHandler handles GET /sse/{params}, where params is
// "x,y,width,height".
//
// It validates the requested area, registers an SSE client for it and then
// streams events until handleSSEConnection returns. The client is
// unregistered when the handler exits. Invalid parameters or an invalid area
// yield a 400 response; a ResponseWriter that cannot flush yields a 500.
func (s *Server) streamEventsHandler(w http.ResponseWriter, r *http.Request) {
	params, err := s.parseSSEParams(r)
	if err != nil {
		s.logger.Error("Invalid SSE parameters", "error", err)
		s.writeErrorResponse(w, "Invalid parameters", http.StatusBadRequest)
		return
	}

	area := Area{
		X:      params[0],
		Y:      params[1],
		Width:  params[2],
		Height: params[3],
	}
	if err := s.validateArea(area); err != nil {
		s.logger.Error("Invalid SSE area", "error", err)
		s.writeErrorResponse(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Check for streaming support before touching the headers so that the
	// JSON error response does not carry SSE-specific headers.
	flusher, ok := w.(http.Flusher)
	if !ok {
		s.writeErrorResponse(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	// The http.Server is configured with a WriteTimeout, which would cut a
	// long-lived stream once it elapses. Clear the write deadline for this
	// connection only. This is best effort: a wrapping ResponseWriter that
	// does not expose the underlying connection makes the call fail, in
	// which case the stream stays subject to the server-wide timeout.
	if err := http.NewResponseController(w).SetWriteDeadline(time.Time{}); err != nil {
		s.logger.Warn("Could not disable write deadline for SSE stream", "error", err)
	}

	// Set SSE headers
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	client := s.createSSEClient(area)
	defer s.removeSSEClient(client.ID)

	s.handleSSEConnection(w, r, client, flusher)
}
// serveIndexHandler handles GET / by rendering the "index.html" template
// with the CSRF template field for the current request.
//
// The page is rendered into a buffer first, so a template failure produces a
// clean 500 response instead of an error message appended to a partly sent
// page.
func (s *Server) serveIndexHandler(w http.ResponseWriter, r *http.Request) {
	data := map[string]any{
		csrf.TemplateTag: csrf.TemplateField(r),
	}

	var page bytes.Buffer
	if err := s.templates.ExecuteTemplate(&page, "index.html", data); err != nil {
		s.logger.Error("Failed to execute template", "error", err)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	if _, err := page.WriteTo(w); err != nil {
		// The response has already been started, so logging is all that
		// can be done.
		s.logger.Error("Failed to write index page", "error", err)
	}
}
// parseSSEParams extracts the "params" route variable and parses it as four
// comma-separated decimal integers, returned in the order they appear.
//
// An error is returned when the number of fields is not exactly four or when
// a field is not a valid integer.
func (s *Server) parseSSEParams(r *http.Request) ([]int, error) {
	fields := strings.Split(mux.Vars(r)["params"], ",")
	if n := len(fields); n != 4 {
		return nil, fmt.Errorf("expected 4 parameters, got %d", n)
	}

	values := make([]int, 0, len(fields))
	for idx, raw := range fields {
		parsed, convErr := strconv.Atoi(raw)
		if convErr != nil {
			return nil, fmt.Errorf("invalid parameter %d: %s", idx, raw)
		}
		values = append(values, parsed)
	}
	return values, nil
}
// validateArea checks the dimensions of a requested area.
//
// Width and Height must both be positive and must not exceed the configured
// MaxWidth and MaxHeight. The area's X and Y are not checked.
func (s *Server) validateArea(area Area) error {
	switch {
	case area.Width <= 0, area.Height <= 0:
		return fmt.Errorf("width and height must be positive")
	case area.Width > s.config.MaxWidth, area.Height > s.config.MaxHeight:
		return fmt.Errorf("area dimensions exceed maximum bounds")
	}
	return nil
}
// createSSEClient registers a new SSE client for area and returns it.
//
// The client ID is "x,y,width,height". When that ID is already registered,
// because another client is watching the same area, a "#n" suffix is added
// so the new client does not replace the existing one. Without the suffix
// the first client to disconnect would close and unregister the other
// client's channel.
//
// LastSent starts at the Unix epoch so the first update is never throttled.
//
// NOTE(review): sseConnections is accessed without any visible locking;
// confirm whether concurrent handlers can reach this method.
func (s *Server) createSSEClient(area Area) *SSEConnection {
	baseID := fmt.Sprintf("%d,%d,%d,%d", area.X, area.Y, area.Width, area.Height)

	clientID := baseID
	for n := 2; ; n++ {
		if _, taken := s.sseConnections[clientID]; !taken {
			break
		}
		clientID = fmt.Sprintf("%s#%d", baseID, n)
	}

	client := &SSEConnection{
		ID:       clientID,
		Channel:  make(chan CellState, 10),
		Area:     area,
		LastSent: time.Unix(0, 0),
	}
	s.sseConnections[clientID] = client
	return client
}
// removeSSEClient unregisters the SSE client with the given ID and closes
// its channel. An unknown ID is silently ignored.
func (s *Server) removeSSEClient(clientID string) {
	conn, registered := s.sseConnections[clientID]
	if !registered {
		return
	}

	close(conn.Channel)
	delete(s.sseConnections, clientID)
	s.logger.Info("SSE client disconnected", "client", clientID)
}
// notifySSEClients offers cellState to every registered SSE client whose
// Area.Contains reports true for the cell's coordinates.
//
// The send never blocks: when a client's channel is full the update is
// dropped for that client and a warning is logged.
func (s *Server) notifySSEClients(cellState CellState) {
	for _, conn := range s.sseConnections {
		if !conn.Area.Contains(cellState.X, cellState.Y) {
			continue
		}

		select {
		case conn.Channel <- cellState:
			// Queued for delivery.
		default:
			// No room left in the queue; drop the update for this client.
			s.logger.Warn("SSE client channel full, skipping update", "client", conn.ID)
		}
	}
}
// handleSSEConnection streams events to one SSE client until it goes away.
//
// It first sends the current points of the client's area via sendPointsState
// and then forwards every CellState received on client.Channel as an
// "update" event. The event data is the base64 encoding of X and Y written
// as little-endian int32 values, followed by the encoding of
// boolToInt(State).
//
// Updates are sent at most once per SSEFlushInterval. An update that arrives
// too early is held until the interval has elapsed and is then sent, so
// updates keep their order.
//
// The function returns when the request context is done, when
// client.Channel has been closed, or when writing to w fails.
func (s *Server) handleSSEConnection(w http.ResponseWriter, r *http.Request, client *SSEConnection, flusher http.Flusher) {
	clientGone := r.Context().Done()

	if err := s.sendPointsState(w, client, flusher); err != nil {
		s.logger.Error("Failed to send points state", "error", err, "client", client.ID)
		return
	}

	for {
		select {
		case <-clientGone:
			s.logger.Info("SSE connection closed", "client", client.ID)
			return

		case cellState, ok := <-client.Channel:
			if !ok {
				// The channel was closed (removeSSEClient or Stop). A closed
				// channel yields zero values forever, so stop here.
				s.logger.Info("SSE connection closed", "client", client.ID)
				return
			}

			// Throttle updates: wait out the rest of the interval, but stop
			// waiting as soon as the client disconnects.
			if wait := s.config.SSEFlushInterval - time.Since(client.LastSent); wait > 0 {
				timer := time.NewTimer(wait)
				select {
				case <-clientGone:
					timer.Stop()
					s.logger.Info("SSE connection closed", "client", client.ID)
					return
				case <-timer.C:
				}
			}

			// Errors from binary.Write are not checked, as before: writes to
			// a bytes.Buffer do not fail for fixed-size values.
			// NOTE(review): this relies on boolToInt returning a fixed-size
			// integer type; confirm against its definition.
			buffer := new(bytes.Buffer)
			binary.Write(buffer, binary.LittleEndian, int32(cellState.X))
			binary.Write(buffer, binary.LittleEndian, int32(cellState.Y))
			binary.Write(buffer, binary.LittleEndian, boolToInt(cellState.State))
			datab64 := base64.StdEncoding.EncodeToString(buffer.Bytes())

			if _, err := fmt.Fprintf(w, "event: update\ndata: %s\n\n", datab64); err != nil {
				// The connection is unusable; give up on this client.
				s.logger.Error("Failed to send update", "error", err, "client", client.ID)
				return
			}
			flusher.Flush()
			client.LastSent = time.Now()
		}
	}
}
// sendPointsState writes a single "state" event holding every point that
// GetLocalPoints returns for the client's area, then flushes the response.
//
// The event data is the base64 encoding of the points, each stored as two
// little-endian int32 values (X followed by Y). The error of the write to w
// is returned.
func (s *Server) sendPointsState(w http.ResponseWriter, client *SSEConnection, flusher http.Flusher) error {
	var (
		payload bytes.Buffer
		pair    [8]byte
	)
	for _, pt := range s.data.GetLocalPoints(client.Area) {
		binary.LittleEndian.PutUint32(pair[:4], uint32(int32(pt.X)))
		binary.LittleEndian.PutUint32(pair[4:], uint32(int32(pt.Y)))
		payload.Write(pair[:])
	}

	encoded := base64.StdEncoding.EncodeToString(payload.Bytes())
	_, writeErr := fmt.Fprintf(w, "event: state\ndata: %s\n\n", encoded)
	if writeErr != nil {
		return writeErr
	}

	flusher.Flush()
	return nil
}
// writeErrorResponse sends statusCode together with a JSON body of the form
// {"error": message}.
//
// A failure to write the body is logged, in the same way setPositionHandler
// logs its encode failures; the status line has already been sent at that
// point, so nothing else can be done.
func (s *Server) writeErrorResponse(w http.ResponseWriter, message string, statusCode int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)

	response := map[string]string{"error": message}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		s.logger.Error("Failed to encode error response", "error", err)
	}
}
// Server lifecycle methods

// Start creates the http.Server for the configured address and serves
// requests until the server stops.
//
// It returns nil when ListenAndServe returns nil or http.ErrServerClosed,
// and a wrapped error for any other failure.
func (s *Server) Start() error {
	listenAddr := fmt.Sprintf("%s:%s", s.config.Host, s.config.Port)

	srv := &http.Server{
		Addr:         listenAddr,
		Handler:      s.router,
		ReadTimeout:  s.config.ReadTimeout,
		WriteTimeout: s.config.WriteTimeout,
	}
	s.httpServer = srv

	s.logger.Info("Server starting", "address", listenAddr)

	switch serveErr := srv.ListenAndServe(); serveErr {
	case nil, http.ErrServerClosed:
		return nil
	default:
		return fmt.Errorf("server failed to start: %w", serveErr)
	}
}
// Stop shuts the server down.
//
// It closes the channel of every registered SSE client and replaces the
// registry with an empty one. It then shuts down the http.Server, if Start
// created one, using ctx. Finally it closes the spatial map, if one is set.
// The first failure is logged and returned wrapped, and the remaining steps
// are skipped.
func (s *Server) Stop(ctx context.Context) error {
	s.logger.Info("Server shutting down...")

	// Drop every SSE subscription.
	for _, conn := range s.sseConnections {
		close(conn.Channel)
	}
	s.sseConnections = map[string]*SSEConnection{}

	if srv := s.httpServer; srv != nil {
		if shutdownErr := srv.Shutdown(ctx); shutdownErr != nil {
			s.logger.Error("Failed to shutdown server", "error", shutdownErr)
			return fmt.Errorf("server shutdown error: %w", shutdownErr)
		}
	}

	if store := s.data; store != nil {
		s.logger.Info("Closing spatial hash map")
		if closeErr := store.Close(); closeErr != nil {
			s.logger.Error("Failed to close spatial hash map", "error", closeErr)
			return fmt.Errorf("failed to close spatial hash map: %w", closeErr)
		}
	}

	s.logger.Info("Server stopped gracefully")
	return nil
}
// Run starts the server with graceful shutdown handling
func Run() error {
return RunWithConfig(nil, nil)
}
// RunWithConfig starts the server with a custom configuration and spatial
// map and blocks until the server fails or SIGINT/SIGTERM is received.
//
// A nil config is replaced by DefaultConfig(). A nil spatialMap is replaced
// by one created from config.DB, whose schema is then initialised. If
// schema initialisation fails the newly created map is closed before the
// error is returned.
//
// On a shutdown signal the server is stopped with a context bounded by
// config.ShutdownTimeout. The spatial map is closed when the function
// returns, whether it was supplied by the caller or created here.
func RunWithConfig(config *Config, spatialMap *SpatialHashMap) (err error) {
	if config == nil {
		config = DefaultConfig()
	}

	if spatialMap == nil {
		spatialMap, err = NewSpatialHashMap(config.DB) // Assumes this function exists
		if err != nil {
			return fmt.Errorf("failed to create spatial hash map: %w", err)
		}
		if err = spatialMap.InitSchema(); err != nil {
			// The deferred Close below is not registered yet, so release
			// the map here instead of leaking it.
			if closeErr := spatialMap.Close(); closeErr != nil {
				return fmt.Errorf("failed to initialize spatial hash map schema: %w (closing the map also failed: %v)", err, closeErr)
			}
			return fmt.Errorf("failed to initialize spatial hash map schema: %w", err)
		}
	}
	defer spatialMap.Close()

	server, err := NewServer(config, spatialMap)
	if err != nil {
		return fmt.Errorf("failed to create server: %w", err)
	}

	// Setup graceful shutdown
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	// Stop relaying signals to sigChan once this function returns.
	defer signal.Stop(sigChan)

	// Start server in goroutine. The channel is buffered so the goroutine
	// can always deliver its result and exit, even when nobody is receiving
	// any more.
	errChan := make(chan error, 1)
	go func() {
		errChan <- server.Start()
	}()

	// Wait for shutdown signal or server error
	select {
	case err := <-errChan:
		if err != nil {
			return err
		}
	case <-sigChan:
		server.logger.Info("Shutdown signal received")

		ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout)
		defer cancel()

		// Stop logs its own success message, so none is repeated here.
		if err := server.Stop(ctx); err != nil {
			server.logger.Error("Server shutdown error", "error", err)
			return err
		}
	}

	return nil
}