diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2559f19 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/artemvang/infinite-place + +go 1.23.3 + +require ( + github.com/gorilla/csrf v1.7.3 + github.com/gorilla/handlers v1.5.2 + github.com/gorilla/mux v1.8.1 + github.com/mattn/go-sqlite3 v1.14.28 +) + +require ( + github.com/felixge/httpsnoop v1.0.3 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f416dca --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= +github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= +github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= +github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/internal/db.go b/internal/db.go new file mode 100644 index 0000000..db84b22 --- /dev/null +++ b/internal/db.go @@ -0,0 +1,121 @@ +package app + +import ( + "sync" +) + +const ( + ChunkSize = 16 +) + +// Point represents the request body for setting cell position +type Point struct { + X, Y int +} + +// Area represents the visible area for SSE updates +type Area struct { + X, Y, Width, Height int +} + +// Contains checks if a point is within the SSE area +func (a Area) Contains(x int, y int) bool { + return x >= a.X && x <= a.X+a.Width && y >= a.Y && y <= a.Y+a.Height +} + +// SpatialHashMap implements spatial hashing for efficient point storage and retrieval +type SpatialHashMap struct { + mu sync.RWMutex + db *SQLiteEngine +} + +// NewSpatialHashMap creates a new spatial hash map +func NewSpatialHashMap(dbPath string) (*SpatialHashMap, error) { + db, err := CreateSQLite(dbPath) + if err != nil { + return nil, err + } + return &SpatialHashMap{ + db: db, + }, nil +} + +func (s *SpatialHashMap) InitSchema() error { + err := s.db.Execute( + `CREATE TABLE IF NOT EXISTS spatial_hashmap ( + cluster_x INTEGER NOT NULL, + cluster_y INTEGER NOT NULL, + x INTEGER NOT NULL, + y INTEGER NOT NULL, + PRIMARY KEY (cluster_x, cluster_y, x, y)) WITHOUT ROWID`) + if err != nil { + return err + } + return nil +} + +// getHash calculates the hash coordinates for a given point +func (s *SpatialHashMap) getHash(x, y int) Point { + return Point{X: x / ChunkSize, Y: y / ChunkSize} +} + +// Insert adds a point to the spatial hash map +func (s *SpatialHashMap) Insert(x, y int) error { + s.mu.Lock() + defer s.mu.Unlock() + + key := s.getHash(x, y) + err := s.db.Execute( + "INSERT INTO spatial_hashmap (cluster_x, cluster_y, x, y) VALUES (?, ?, ?, ?)", + key.X, key.Y, x, y) + if err != nil { + return err + + } + return nil +} + +// Remove removes a point from the spatial hash map +func (s *SpatialHashMap) Remove(x, y int) error { + s.mu.Lock() + defer s.mu.Unlock() + + key := s.getHash(x, y) + err := s.db.Execute( + "DELETE FROM spatial_hashmap WHERE cluster_x = ? AND cluster_y = ? AND x = ? AND y = ?", + key.X, key.Y, x, y) + if err != nil { + return err + } + return nil +} + +// GetLocalPoints returns all points within the given bounding box +func (s *SpatialHashMap) GetLocalPoints(area Area) []Point { + hashStart, hashEnd := s.getHash(area.X, area.Y), s.getHash(area.X+area.Width, area.Y+area.Height) + + rows, err := s.db.Query( + "SELECT x, y FROM spatial_hashmap WHERE cluster_x BETWEEN ? AND ? AND cluster_y BETWEEN ? AND ?", + hashStart.X, hashEnd.X, hashStart.Y, hashEnd.Y) + if err != nil { + return nil + } + defer rows.Close() + + localPoints := make([]Point, 0) + for rows.Next() { + var x, y int + if err := rows.Scan(&x, &y); err != nil { + continue + } + localPoints = append(localPoints, Point{X: x, Y: y}) + } + return localPoints +} + +func (s *SpatialHashMap) Close() error { + if s.db != nil { + return s.db.Close() + } + return nil +} diff --git a/internal/server.go b/internal/server.go new file mode 100644 index 0000000..e5a7fcc --- /dev/null +++ b/internal/server.go @@ -0,0 +1,481 @@ +package app + +import ( + "bytes" + "context" + "embed" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "log" + "log/slog" + "net/http" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "text/template" + "time" + + "github.com/gorilla/csrf" + "github.com/gorilla/handlers" + "github.com/gorilla/mux" +) + +//go:embed templates +var tplFolder embed.FS + +// Config holds server configuration +type Config struct { + Host string + Port string + CSRFKey []byte + TrustedOrigins []string + MaxWidth int + MaxHeight int + SSEFlushInterval time.Duration + WriteTimeout time.Duration + ReadTimeout time.Duration + ShutdownTimeout time.Duration + DB string +} + +// DefaultConfig returns default server configuration +func DefaultConfig() *Config { + return &Config{ + Host: "localhost", + Port: "5002", + CSRFKey: []byte("0e6139e71f1972259e4b2ce6464b80a3"), // Should be from env in production + TrustedOrigins: []string{"localhost:5002"}, + MaxWidth: 1024, + MaxHeight: 1024, + SSEFlushInterval: 200 * time.Millisecond, + WriteTimeout: 15 * time.Second, + ReadTimeout: 15 * time.Second, + ShutdownTimeout: 10 * time.Second, + DB: "points.db?_journal_mode=WAL&cache=shared&_foreign_keys=0&synchronous=1", + } +} + +type CellState struct { + X int `json:"x"` + Y int `json:"y"` + State bool `json:"state"` +} + +// SSEConnection represents an SSE connection +type SSEConnection struct { + ID string + Channel chan CellState + Area Area + LastSent time.Time +} + +// Server encapsulates the HTTP server and its dependencies +type Server struct { + config *Config + logger *slog.Logger + templates *template.Template + data *SpatialHashMap + sseConnections map[string]*SSEConnection + router *mux.Router + httpServer *http.Server +} + +// NewServer creates a new server instance +func NewServer(config *Config, spatialMap *SpatialHashMap) (*Server, error) { + if config == nil { + config = DefaultConfig() + } + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + + templates, err := template.ParseFS(tplFolder, "templates/*.html") + if err != nil { + log.Fatal(err) + } + + server := &Server{ + config: config, + logger: logger, + templates: templates, + data: spatialMap, + sseConnections: make(map[string]*SSEConnection), + } + + server.setupRoutes() + return server, nil +} + +// setupRoutes configures the HTTP routes and middleware +func (s *Server) setupRoutes() { + s.router = mux.NewRouter() + + // Middleware chain + s.router.Use(s.loggingMiddleware) + s.router.Use(s.corsMiddleware) + s.router.Use(s.csrfMiddleware) + + // Routes + s.router.HandleFunc("/api/set-pos", s.setPositionHandler).Methods("POST") + s.router.HandleFunc("/sse/{params}", s.streamEventsHandler).Methods("GET") + s.router.HandleFunc("/", s.serveIndexHandler).Methods("GET") +} + +// Middleware functions +func (s *Server) loggingMiddleware(next http.Handler) http.Handler { + return handlers.LoggingHandler(os.Stdout, next) +} + +func (s *Server) corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +func (s *Server) csrfMiddleware(next http.Handler) http.Handler { + csrfProtection := csrf.Protect( + s.config.CSRFKey, + csrf.Secure(false), // Set to true in production with HTTPS + csrf.HttpOnly(true), + csrf.TrustedOrigins(s.config.TrustedOrigins), + ) + return csrfProtection(next) +} + +// HTTP Handlers +func (s *Server) setPositionHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var cellState CellState + if err := json.NewDecoder(r.Body).Decode(&cellState); err != nil { + s.logger.Error("Failed to decode JSON", "error", err) + s.writeErrorResponse(w, "Invalid JSON", http.StatusBadRequest) + return + } + + // Update spatial data + if cellState.State { + s.data.Insert(cellState.X, cellState.Y) + } else { + s.data.Remove(cellState.X, cellState.Y) + } + + // Notify SSE clients + s.notifySSEClients(cellState) + + response := map[string]string{"status": "ok"} + if err := json.NewEncoder(w).Encode(response); err != nil { + s.logger.Error("Failed to encode response", "error", err) + } +} + +func (s *Server) streamEventsHandler(w http.ResponseWriter, r *http.Request) { + params, err := s.parseSSEParams(r) + if err != nil { + s.logger.Error("Invalid SSE parameters", "error", err) + s.writeErrorResponse(w, "Invalid parameters", http.StatusBadRequest) + return + } + + area := Area{ + X: params[0], + Y: params[1], + Width: params[2], + Height: params[3], + } + + if err := s.validateArea(area); err != nil { + s.logger.Error("Invalid SSE area", "error", err) + s.writeErrorResponse(w, err.Error(), http.StatusBadRequest) + return + } + + // Set SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + s.writeErrorResponse(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + client := s.createSSEClient(area) + defer s.removeSSEClient(client.ID) + + s.handleSSEConnection(w, r, client, flusher) +} + +func (s *Server) serveIndexHandler(w http.ResponseWriter, r *http.Request) { + data := map[string]any{ + csrf.TemplateTag: csrf.TemplateField(r), + } + + if err := s.templates.ExecuteTemplate(w, "index.html", data); err != nil { + s.logger.Error("Failed to execute template", "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + } +} + +func (s *Server) parseSSEParams(r *http.Request) ([]int, error) { + vars := mux.Vars(r) + path := vars["params"] + + paramStrs := strings.Split(path, ",") + if len(paramStrs) != 4 { + return nil, fmt.Errorf("expected 4 parameters, got %d", len(paramStrs)) + } + + params := make([]int, 4) + for i, paramStr := range paramStrs { + var err error + params[i], err = strconv.Atoi(paramStr) + if err != nil { + return nil, fmt.Errorf("invalid parameter %d: %s", i, paramStr) + } + } + + return params, nil +} + +func (s *Server) validateArea(area Area) error { + if area.Width <= 0 || area.Height <= 0 { + return fmt.Errorf("width and height must be positive") + } + if area.Width > s.config.MaxWidth || area.Height > s.config.MaxHeight { + return fmt.Errorf("area dimensions exceed maximum bounds") + } + return nil +} + +func (s *Server) createSSEClient(area Area) *SSEConnection { + clientID := fmt.Sprintf("%d,%d,%d,%d", area.X, area.Y, area.Width, area.Height) + client := &SSEConnection{ + ID: clientID, + Channel: make(chan CellState, 10), + Area: area, + LastSent: time.Unix(0, 0), + } + + s.sseConnections[clientID] = client + + return client +} + +func (s *Server) removeSSEClient(clientID string) { + if client, exists := s.sseConnections[clientID]; exists { + close(client.Channel) + delete(s.sseConnections, clientID) + s.logger.Info("SSE client disconnected", "client", clientID) + } +} + +func (s *Server) notifySSEClients(cellState CellState) { + for _, client := range s.sseConnections { + if client.Area.Contains(cellState.X, cellState.Y) { + select { + case client.Channel <- cellState: + default: + // Channel is full, skip this update + s.logger.Warn("SSE client channel full, skipping update", "client", client.ID) + } + } + } +} + +func (s *Server) handleSSEConnection(w http.ResponseWriter, r *http.Request, client *SSEConnection, flusher http.Flusher) { + clientGone := r.Context().Done() + + if err := s.sendPointsState(w, client, flusher); err != nil { + s.logger.Error("Failed to send points state", "error", err, "client", client.ID) + return + } + + for { + select { + case <-clientGone: + s.logger.Info("SSE connection closed", "client", client.ID) + return + + case cellState := <-client.Channel: + // Throttle updates + if since := time.Since(client.LastSent); since < s.config.SSEFlushInterval { + select { + case client.Channel <- cellState: + default: + // Channel is full, skip this update + s.logger.Warn("SSE client channel full, skipping update", "client", client.ID) + } + time.Sleep(s.config.SSEFlushInterval - since) + continue + } + + buffer := new(bytes.Buffer) + binary.Write(buffer, binary.LittleEndian, int32(cellState.X)) + binary.Write(buffer, binary.LittleEndian, int32(cellState.Y)) + binary.Write(buffer, binary.LittleEndian, boolToInt(cellState.State)) + + datab64 := base64.StdEncoding.EncodeToString(buffer.Bytes()) + + if _, err := fmt.Fprintf(w, "event: update\ndata: %s\n\n", datab64); err != nil { + + } + flusher.Flush() + + client.LastSent = time.Now() + } + } +} + +func (s *Server) sendPointsState(w http.ResponseWriter, client *SSEConnection, flusher http.Flusher) error { + area := client.Area + points := s.data.GetLocalPoints(area) + + buffer := new(bytes.Buffer) + for _, point := range points { + binary.Write(buffer, binary.LittleEndian, int32(point.X)) + binary.Write(buffer, binary.LittleEndian, int32(point.Y)) + } + datab64 := base64.StdEncoding.EncodeToString(buffer.Bytes()) + + if _, err := fmt.Fprintf(w, "event: state\ndata: %s\n\n", datab64); err != nil { + return err + } + + flusher.Flush() + return nil +} + +func (s *Server) writeErrorResponse(w http.ResponseWriter, message string, statusCode int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + response := map[string]string{"error": message} + json.NewEncoder(w).Encode(response) +} + +// Server lifecycle methods +func (s *Server) Start() error { + addr := fmt.Sprintf("%s:%s", s.config.Host, s.config.Port) + + s.httpServer = &http.Server{ + Handler: s.router, + Addr: addr, + WriteTimeout: s.config.WriteTimeout, + ReadTimeout: s.config.ReadTimeout, + } + + s.logger.Info("Server starting", "address", addr) + + if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("server failed to start: %w", err) + } + + return nil +} + +func (s *Server) Stop(ctx context.Context) error { + s.logger.Info("Server shutting down...") + + // Close all SSE connections + for _, client := range s.sseConnections { + close(client.Channel) + } + s.sseConnections = make(map[string]*SSEConnection) + + if s.httpServer != nil { + err := s.httpServer.Shutdown(ctx) + if err != nil { + s.logger.Error("Failed to shutdown server", "error", err) + return fmt.Errorf("server shutdown error: %w", err) + } + } + + if s.data != nil { + s.logger.Info("Closing spatial hash map") + err := s.data.Close() + if err != nil { + s.logger.Error("Failed to close spatial hash map", "error", err) + return fmt.Errorf("failed to close spatial hash map: %w", err) + } + } + s.logger.Info("Server stopped gracefully") + + return nil +} + +// Run starts the server with graceful shutdown handling +func Run() error { + return RunWithConfig(nil, nil) +} + +// RunWithConfig starts the server with custom configuration and spatial map +func RunWithConfig(config *Config, spatialMap *SpatialHashMap) (err error) { + if config == nil { + config = DefaultConfig() + } + + if spatialMap == nil { + spatialMap, err = NewSpatialHashMap(config.DB) // Assumes this function exists + if err != nil { + return fmt.Errorf("failed to create spatial hash map: %w", err) + } + err = spatialMap.InitSchema() + if err != nil { + return fmt.Errorf("failed to create spatial hash map: %w", err) + } + } + defer spatialMap.Close() + + server, err := NewServer(config, spatialMap) + if err != nil { + return fmt.Errorf("failed to create server: %w", err) + } + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Start server in goroutine + errChan := make(chan error, 1) + go func() { + errChan <- server.Start() + }() + + // Wait for shutdown signal or server error + select { + case err := <-errChan: + if err != nil { + return err + } + case <-sigChan: + server.logger.Info("Shutdown signal received") + + ctx, cancel := context.WithTimeout(context.Background(), config.ShutdownTimeout) + defer cancel() + + if err := server.Stop(ctx); err != nil { + server.logger.Error("Server shutdown error", "error", err) + return err + } + + server.logger.Info("Server stopped gracefully") + } + + return nil +} diff --git a/internal/sqlite.go b/internal/sqlite.go new file mode 100644 index 0000000..ec54cef --- /dev/null +++ b/internal/sqlite.go @@ -0,0 +1,45 @@ +package app + +import ( + "database/sql" + + _ "github.com/mattn/go-sqlite3" +) + +type SQLiteEngine struct { + connection *sql.DB +} + +func CreateSQLite(url string) (*SQLiteEngine, error) { + sqliteDatabase, err := sql.Open("sqlite3", url) + if err != nil { + return nil, err + } + + return &SQLiteEngine{connection: sqliteDatabase}, nil +} + +func (e *SQLiteEngine) Query(sql string, args ...any) (*sql.Rows, error) { + result, err := e.connection.Query(sql, args...) + if err != nil { + return nil, err + } + + return result, nil +} + +func (e *SQLiteEngine) Execute(sql string, args ...any) error { + _, err := e.connection.Exec(sql, args...) + if err != nil { + return err + } + + return nil +} + +func (e *SQLiteEngine) Close() error { + if e.connection != nil { + return e.connection.Close() + } + return nil +} diff --git a/internal/templates/index.html b/internal/templates/index.html new file mode 100644 index 0000000..94d4f89 --- /dev/null +++ b/internal/templates/index.html @@ -0,0 +1,594 @@ + + + +
+ + +