482 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			482 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package app
 | 
						|
 | 
						|
import (
	"bytes"
	"context"
	"embed"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"text/template"
	"time"

	"github.com/gorilla/csrf"
	"github.com/gorilla/handlers"
	"github.com/gorilla/mux"
)
 | 
						|
 | 
						|
//go:embed templates
 | 
						|
var tplFolder embed.FS
 | 
						|
 | 
						|
// Config holds the settings used to build and run a Server.
type Config struct {
	// Host and Port together form the listen address.
	Host, Port string
	// CSRFKey is the secret passed to csrf.Protect; keep it private.
	CSRFKey []byte
	// TrustedOrigins is passed to csrf.TrustedOrigins.
	TrustedOrigins []string
	// MaxWidth and MaxHeight bound the area an SSE client may subscribe to.
	MaxWidth, MaxHeight int
	// SSEFlushInterval is the minimum spacing between update writes to one
	// SSE client. WriteTimeout and ReadTimeout are applied to the
	// http.Server, and ShutdownTimeout bounds graceful shutdown.
	SSEFlushInterval, WriteTimeout, ReadTimeout, ShutdownTimeout time.Duration
	// DB is the data-source string handed to NewSpatialHashMap.
	DB string
}
 | 
						|
 | 
						|
// DefaultConfig returns default server configuration
 | 
						|
func DefaultConfig() *Config {
 | 
						|
	return &Config{
 | 
						|
		Host:             "localhost",
 | 
						|
		Port:             "5002",
 | 
						|
		CSRFKey:          []byte("0e6139e71f1972259e4b2ce6464b80a3"), // Should be from env in production
 | 
						|
		TrustedOrigins:   []string{"localhost:5002"},
 | 
						|
		MaxWidth:         1024,
 | 
						|
		MaxHeight:        1024,
 | 
						|
		SSEFlushInterval: 200 * time.Millisecond,
 | 
						|
		WriteTimeout:     15 * time.Second,
 | 
						|
		ReadTimeout:      15 * time.Second,
 | 
						|
		ShutdownTimeout:  10 * time.Second,
 | 
						|
		DB:               "points.db?_journal_mode=WAL&cache=shared&_foreign_keys=0&synchronous=1",
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// CellState is the JSON body of a cell toggle and the value broadcast to SSE
// clients: the cell's coordinates and whether it is now set (true, inserted
// into the spatial map) or cleared (false, removed).
type CellState struct {
	X     int  `json:"x"`     // horizontal coordinate
	Y     int  `json:"y"`     // vertical coordinate
	State bool `json:"state"` // true = set, false = cleared
}
 | 
						|
 | 
						|
// SSEConnection represents an SSE connection
 | 
						|
type SSEConnection struct {
 | 
						|
	ID       string
 | 
						|
	Channel  chan CellState
 | 
						|
	Area     Area
 | 
						|
	LastSent time.Time
 | 
						|
}
 | 
						|
 | 
						|
// Server encapsulates the HTTP server and its dependencies
 | 
						|
type Server struct {
 | 
						|
	config         *Config
 | 
						|
	logger         *slog.Logger
 | 
						|
	templates      *template.Template
 | 
						|
	data           *SpatialHashMap
 | 
						|
	sseConnections map[string]*SSEConnection
 | 
						|
	router         *mux.Router
 | 
						|
	httpServer     *http.Server
 | 
						|
}
 | 
						|
 | 
						|
// NewServer creates a new server instance
 | 
						|
func NewServer(config *Config, spatialMap *SpatialHashMap) (*Server, error) {
 | 
						|
	if config == nil {
 | 
						|
		config = DefaultConfig()
 | 
						|
	}
 | 
						|
 | 
						|
	logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
 | 
						|
		Level: slog.LevelInfo,
 | 
						|
	}))
 | 
						|
 | 
						|
	templates, err := template.ParseFS(tplFolder, "templates/*.html")
 | 
						|
	if err != nil {
 | 
						|
		log.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	server := &Server{
 | 
						|
		config:         config,
 | 
						|
		logger:         logger,
 | 
						|
		templates:      templates,
 | 
						|
		data:           spatialMap,
 | 
						|
		sseConnections: make(map[string]*SSEConnection),
 | 
						|
	}
 | 
						|
 | 
						|
	server.setupRoutes()
 | 
						|
	return server, nil
 | 
						|
}
 | 
						|
 | 
						|
// setupRoutes configures the HTTP routes and middleware
 | 
						|
func (s *Server) setupRoutes() {
 | 
						|
	s.router = mux.NewRouter()
 | 
						|
 | 
						|
	// Middleware chain
 | 
						|
	s.router.Use(s.loggingMiddleware)
 | 
						|
	s.router.Use(s.corsMiddleware)
 | 
						|
	s.router.Use(s.csrfMiddleware)
 | 
						|
 | 
						|
	// Routes
 | 
						|
	s.router.HandleFunc("/api/set-pos", s.setPositionHandler).Methods("POST")
 | 
						|
	s.router.HandleFunc("/sse/{params}", s.streamEventsHandler).Methods("GET")
 | 
						|
	s.router.HandleFunc("/", s.serveIndexHandler).Methods("GET")
 | 
						|
}
 | 
						|
 | 
						|
// Middleware functions
 | 
						|
func (s *Server) loggingMiddleware(next http.Handler) http.Handler {
 | 
						|
	return handlers.LoggingHandler(os.Stdout, next)
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
 | 
						|
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						|
		w.Header().Set("Access-Control-Allow-Origin", "*")
 | 
						|
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
 | 
						|
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
 | 
						|
 | 
						|
		if r.Method == "OPTIONS" {
 | 
						|
			w.WriteHeader(http.StatusOK)
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		next.ServeHTTP(w, r)
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) csrfMiddleware(next http.Handler) http.Handler {
 | 
						|
	csrfProtection := csrf.Protect(
 | 
						|
		s.config.CSRFKey,
 | 
						|
		csrf.Secure(false), // Set to true in production with HTTPS
 | 
						|
		csrf.HttpOnly(true),
 | 
						|
		csrf.TrustedOrigins(s.config.TrustedOrigins),
 | 
						|
	)
 | 
						|
	return csrfProtection(next)
 | 
						|
}
 | 
						|
 | 
						|
// HTTP Handlers
 | 
						|
func (s *Server) setPositionHandler(w http.ResponseWriter, r *http.Request) {
 | 
						|
	w.Header().Set("Content-Type", "application/json")
 | 
						|
 | 
						|
	var cellState CellState
 | 
						|
	if err := json.NewDecoder(r.Body).Decode(&cellState); err != nil {
 | 
						|
		s.logger.Error("Failed to decode JSON", "error", err)
 | 
						|
		s.writeErrorResponse(w, "Invalid JSON", http.StatusBadRequest)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	// Update spatial data
 | 
						|
	if cellState.State {
 | 
						|
		s.data.Insert(cellState.X, cellState.Y)
 | 
						|
	} else {
 | 
						|
		s.data.Remove(cellState.X, cellState.Y)
 | 
						|
	}
 | 
						|
 | 
						|
	// Notify SSE clients
 | 
						|
	s.notifySSEClients(cellState)
 | 
						|
 | 
						|
	response := map[string]string{"status": "ok"}
 | 
						|
	if err := json.NewEncoder(w).Encode(response); err != nil {
 | 
						|
		s.logger.Error("Failed to encode response", "error", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) streamEventsHandler(w http.ResponseWriter, r *http.Request) {
 | 
						|
	params, err := s.parseSSEParams(r)
 | 
						|
	if err != nil {
 | 
						|
		s.logger.Error("Invalid SSE parameters", "error", err)
 | 
						|
		s.writeErrorResponse(w, "Invalid parameters", http.StatusBadRequest)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	area := Area{
 | 
						|
		X:      params[0],
 | 
						|
		Y:      params[1],
 | 
						|
		Width:  params[2],
 | 
						|
		Height: params[3],
 | 
						|
	}
 | 
						|
 | 
						|
	if err := s.validateArea(area); err != nil {
 | 
						|
		s.logger.Error("Invalid SSE area", "error", err)
 | 
						|
		s.writeErrorResponse(w, err.Error(), http.StatusBadRequest)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	// Set SSE headers
 | 
						|
	w.Header().Set("Content-Type", "text/event-stream")
 | 
						|
	w.Header().Set("Cache-Control", "no-cache")
 | 
						|
	w.Header().Set("Connection", "keep-alive")
 | 
						|
 | 
						|
	flusher, ok := w.(http.Flusher)
 | 
						|
	if !ok {
 | 
						|
		s.writeErrorResponse(w, "Streaming unsupported", http.StatusInternalServerError)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	client := s.createSSEClient(area)
 | 
						|
	defer s.removeSSEClient(client.ID)
 | 
						|
 | 
						|
	s.handleSSEConnection(w, r, client, flusher)
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) serveIndexHandler(w http.ResponseWriter, r *http.Request) {
 | 
						|
	data := map[string]any{
 | 
						|
		csrf.TemplateTag: csrf.TemplateField(r),
 | 
						|
	}
 | 
						|
 | 
						|
	if err := s.templates.ExecuteTemplate(w, "index.html", data); err != nil {
 | 
						|
		s.logger.Error("Failed to execute template", "error", err)
 | 
						|
		http.Error(w, "Internal server error", http.StatusInternalServerError)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) parseSSEParams(r *http.Request) ([]int, error) {
 | 
						|
	vars := mux.Vars(r)
 | 
						|
	path := vars["params"]
 | 
						|
 | 
						|
	paramStrs := strings.Split(path, ",")
 | 
						|
	if len(paramStrs) != 4 {
 | 
						|
		return nil, fmt.Errorf("expected 4 parameters, got %d", len(paramStrs))
 | 
						|
	}
 | 
						|
 | 
						|
	params := make([]int, 4)
 | 
						|
	for i, paramStr := range paramStrs {
 | 
						|
		var err error
 | 
						|
		params[i], err = strconv.Atoi(paramStr)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("invalid parameter %d: %s", i, paramStr)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return params, nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) validateArea(area Area) error {
 | 
						|
	if area.Width <= 0 || area.Height <= 0 {
 | 
						|
		return fmt.Errorf("width and height must be positive")
 | 
						|
	}
 | 
						|
	if area.Width > s.config.MaxWidth || area.Height > s.config.MaxHeight {
 | 
						|
		return fmt.Errorf("area dimensions exceed maximum bounds")
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) createSSEClient(area Area) *SSEConnection {
 | 
						|
	clientID := fmt.Sprintf("%d,%d,%d,%d", area.X, area.Y, area.Width, area.Height)
 | 
						|
	client := &SSEConnection{
 | 
						|
		ID:       clientID,
 | 
						|
		Channel:  make(chan CellState, 10),
 | 
						|
		Area:     area,
 | 
						|
		LastSent: time.Unix(0, 0),
 | 
						|
	}
 | 
						|
 | 
						|
	s.sseConnections[clientID] = client
 | 
						|
 | 
						|
	return client
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) removeSSEClient(clientID string) {
 | 
						|
	if client, exists := s.sseConnections[clientID]; exists {
 | 
						|
		close(client.Channel)
 | 
						|
		delete(s.sseConnections, clientID)
 | 
						|
		s.logger.Info("SSE client disconnected", "client", clientID)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) notifySSEClients(cellState CellState) {
 | 
						|
	for _, client := range s.sseConnections {
 | 
						|
		if client.Area.Contains(cellState.X, cellState.Y) {
 | 
						|
			select {
 | 
						|
			case client.Channel <- cellState:
 | 
						|
			default:
 | 
						|
				// Channel is full, skip this update
 | 
						|
				s.logger.Warn("SSE client channel full, skipping update", "client", client.ID)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) handleSSEConnection(w http.ResponseWriter, r *http.Request, client *SSEConnection, flusher http.Flusher) {
 | 
						|
	clientGone := r.Context().Done()
 | 
						|
 | 
						|
	if err := s.sendPointsState(w, client, flusher); err != nil {
 | 
						|
		s.logger.Error("Failed to send points state", "error", err, "client", client.ID)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case <-clientGone:
 | 
						|
			s.logger.Info("SSE connection closed", "client", client.ID)
 | 
						|
			return
 | 
						|
 | 
						|
		case cellState := <-client.Channel:
 | 
						|
			// Throttle updates
 | 
						|
			if since := time.Since(client.LastSent); since < s.config.SSEFlushInterval {
 | 
						|
				select {
 | 
						|
				case client.Channel <- cellState:
 | 
						|
				default:
 | 
						|
					// Channel is full, skip this update
 | 
						|
					s.logger.Warn("SSE client channel full, skipping update", "client", client.ID)
 | 
						|
				}
 | 
						|
				time.Sleep(s.config.SSEFlushInterval - since)
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			buffer := new(bytes.Buffer)
 | 
						|
			binary.Write(buffer, binary.LittleEndian, int32(cellState.X))
 | 
						|
			binary.Write(buffer, binary.LittleEndian, int32(cellState.Y))
 | 
						|
			binary.Write(buffer, binary.LittleEndian, boolToInt(cellState.State))
 | 
						|
 | 
						|
			datab64 := base64.StdEncoding.EncodeToString(buffer.Bytes())
 | 
						|
 | 
						|
			if _, err := fmt.Fprintf(w, "event: update\ndata: %s\n\n", datab64); err != nil {
 | 
						|
 | 
						|
			}
 | 
						|
			flusher.Flush()
 | 
						|
 | 
						|
			client.LastSent = time.Now()
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) sendPointsState(w http.ResponseWriter, client *SSEConnection, flusher http.Flusher) error {
 | 
						|
	area := client.Area
 | 
						|
	points := s.data.GetLocalPoints(area)
 | 
						|
 | 
						|
	buffer := new(bytes.Buffer)
 | 
						|
	for _, point := range points {
 | 
						|
		binary.Write(buffer, binary.LittleEndian, int32(point.X))
 | 
						|
		binary.Write(buffer, binary.LittleEndian, int32(point.Y))
 | 
						|
	}
 | 
						|
	datab64 := base64.StdEncoding.EncodeToString(buffer.Bytes())
 | 
						|
 | 
						|
	if _, err := fmt.Fprintf(w, "event: state\ndata: %s\n\n", datab64); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	flusher.Flush()
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) writeErrorResponse(w http.ResponseWriter, message string, statusCode int) {
 | 
						|
	w.Header().Set("Content-Type", "application/json")
 | 
						|
	w.WriteHeader(statusCode)
 | 
						|
 | 
						|
	response := map[string]string{"error": message}
 | 
						|
	json.NewEncoder(w).Encode(response)
 | 
						|
}
 | 
						|
 | 
						|
// Server lifecycle methods
 | 
						|
func (s *Server) Start() error {
 | 
						|
	addr := fmt.Sprintf("%s:%s", s.config.Host, s.config.Port)
 | 
						|
 | 
						|
	s.httpServer = &http.Server{
 | 
						|
		Handler:      s.router,
 | 
						|
		Addr:         addr,
 | 
						|
		WriteTimeout: s.config.WriteTimeout,
 | 
						|
		ReadTimeout:  s.config.ReadTimeout,
 | 
						|
	}
 | 
						|
 | 
						|
	s.logger.Info("Server starting", "address", addr)
 | 
						|
 | 
						|
	if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
 | 
						|
		return fmt.Errorf("server failed to start: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *Server) Stop(ctx context.Context) error {
 | 
						|
	s.logger.Info("Server shutting down...")
 | 
						|
 | 
						|
	// Close all SSE connections
 | 
						|
	for _, client := range s.sseConnections {
 | 
						|
		close(client.Channel)
 | 
						|
	}
 | 
						|
	s.sseConnections = make(map[string]*SSEConnection)
 | 
						|
 | 
						|
	if s.httpServer != nil {
 | 
						|
		err := s.httpServer.Shutdown(ctx)
 | 
						|
		if err != nil {
 | 
						|
			s.logger.Error("Failed to shutdown server", "error", err)
 | 
						|
			return fmt.Errorf("server shutdown error: %w", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if s.data != nil {
 | 
						|
		s.logger.Info("Closing spatial hash map")
 | 
						|
		err := s.data.Close()
 | 
						|
		if err != nil {
 | 
						|
			s.logger.Error("Failed to close spatial hash map", "error", err)
 | 
						|
			return fmt.Errorf("failed to close spatial hash map: %w", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	s.logger.Info("Server stopped gracefully")
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Run starts the server with graceful shutdown handling
 | 
						|
func Run() error {
 | 
						|
	return RunWithConfig(nil, nil)
 | 
						|
}
 | 
						|
 | 
						|
// RunWithConfig starts the server with custom configuration and spatial map
 | 
						|
func RunWithConfig(config *Config, spatialMap *SpatialHashMap) (err error) {
 | 
						|
	if config == nil {
 | 
						|
		config = DefaultConfig()
 | 
						|
	}
 | 
						|
 | 
						|
	if spatialMap == nil {
 | 
						|
		spatialMap, err = NewSpatialHashMap(config.DB) // Assumes this function exists
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to create spatial hash map: %w", err)
 | 
						|
		}
 | 
						|
		err = spatialMap.InitSchema()
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to create spatial hash map: %w", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	defer spatialMap.Close()
 | 
						|
 | 
						|
	server, err := NewServer(config, spatialMap)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to create server: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Setup graceful shutdown
 | 
						|
	sigChan := make(chan os.Signal, 1)
 | 
						|
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
 | 
						|
 | 
						|
	// Start server in goroutine
 | 
						|
	errChan := make(chan error, 1)
 | 
						|
	go func() {
 | 
						|
		errChan <- server.Start()
 | 
						|
	}()
 | 
						|
 | 
						|
	// Wait for shutdown signal or server error
 | 
						|
	select {
 | 
						|
	case err := <-errChan:
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	case <-sigChan:
 | 
						|
		server.logger.Info("Shutdown signal received")
 | 
						|
 | 
						|
		ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout)
 | 
						|
		defer cancel()
 | 
						|
 | 
						|
		if err := server.Stop(ctx); err != nil {
 | 
						|
			server.logger.Error("Server shutdown error", "error", err)
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		server.logger.Info("Server stopped gracefully")
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 |