package ws import ( "log" "net/http" "time" "github.com/gorilla/websocket" "git.saqut.com/saqut/mwse/internal/protocol" ) // Heartbeat and transport defaults. PingPeriod matches the original server's 10s // interval; the ping payload is the same magic string the SDK's pong echoes. const ( defaultWriteWait = 10 * time.Second defaultPongWait = 60 * time.Second defaultPingPeriod = 10 * time.Second defaultMaxMessageSize = 4 << 20 // 4 MiB pingPayload = "saQut" ) // Server upgrades HTTP requests to WebSocket connections and runs each one's // lifecycle. It holds no per-connection state itself; everything lives on the Hub. type Server struct { hub *Hub upgrader websocket.Upgrader pingPeriod time.Duration pongWait time.Duration } // NewServer returns a Server bound to hub. CheckOrigin always returns true to // preserve the original autoAcceptConnections behaviour (origin policy is a // deployment concern handled in front of the engine). func NewServer(hub *Hub) *Server { return &Server{ hub: hub, upgrader: websocket.Upgrader{ ReadBufferSize: 4096, WriteBufferSize: 4096, CheckOrigin: func(*http.Request) bool { return true }, }, pingPeriod: defaultPingPeriod, pongWait: defaultPongWait, } } // ServeHTTP implements http.Handler: it upgrades the request and hands the // connection to the lifecycle. It is the WebSocket endpoint. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("ws: upgrade failed: %v", err) return } s.handle(conn) } // handle drives one connection from accept to disconnect. It is generic over the // Conn interface so tests can feed it a scripted in-memory connection. func (s *Server) handle(conn Conn) { client := NewClient(conn, newUUID()) // Connect: register, start the writer, fire connect listeners (id, private // room, session defaults). Must happen before the read loop so the client can // already receive server-initiated messages. s.hub.Connect(client) // Heartbeat lives in its own goroutine; it stops when the client closes. go s.pingLoop(client) // Read loop blocks here until the socket errors or closes. s.readLoop(client) // Disconnect runs exactly once, here, when the read loop ends. s.hub.Disconnect(client) } // readLoop consumes inbound frames. Pongs that do not echo the magic payload drop // the connection (matching the original); a valid pong extends the read deadline. func (s *Server) readLoop(c *Client) { c.conn.SetReadLimit(defaultMaxMessageSize) _ = c.conn.SetReadDeadline(time.Now().Add(s.pongWait)) c.conn.SetPongHandler(func(appData string) error { if appData != pingPayload { return errBadPong } return c.conn.SetReadDeadline(time.Now().Add(s.pongWait)) }) for { msgType, data, err := c.conn.ReadMessage() if err != nil { return } if msgType != websocket.TextMessage { continue // the protocol is JSON text; ignore binary frames } s.dispatch(c, data) } } // errBadPong is returned by the pong handler to force a disconnect. var errBadPong = &pongError{} type pongError struct{} func (*pongError) Error() string { return "ws: pong validation failed" } // dispatch decodes one frame and routes it, then replies according to the WSTS // rules. A frame that fails to decode is logged as a message error (the Node // server emitted a 'messageError' event here). func (s *Server) dispatch(c *Client, data []byte) { env, err := protocol.Decode(data) if err != nil { log.Printf("ws: message error from %s: %v", c.ID, err) return } result := s.hub.Handle(c, env.Message) if flag, ok := env.WantsReply(); ok { c.Send(protocol.Reply(result, env.ID, flag)) return } // "No id" branch: the original inspected the result for a broadcast directive. // No service currently emits one, but the hook is preserved for parity. if env.IsBroadcast() { if m, ok := result.(map[string]any); ok { if _, has := m["broadcast"]; has { log.Printf("ws: broadcast directive from %s (no listener registered)", c.ID) } } } } // pingLoop sends a ping with the magic payload every ping period until the client // closes. func (s *Server) pingLoop(c *Client) { ticker := time.NewTicker(s.pingPeriod) defer ticker.Stop() for { select { case <-ticker.C: err := c.conn.WriteControl( websocket.PingMessage, []byte(pingPayload), time.Now().Add(defaultWriteWait), ) if err != nil { c.Close() return } case <-c.Done(): return } } } // ---- lifecycle helpers (shared by the server and by tests) --------------- // Connect registers a client, starts its writer goroutine, and fires the connect // listeners. Exposed so tests and tools can drive the same path the server uses. func (h *Hub) Connect(c *Client) { h.addClient(c) go c.writePump() h.emitConnect(c) } // Disconnect fires the disconnect listeners, unregisters the client, and tears // its transport down. Safe to call once per client. func (h *Hub) Disconnect(c *Client) { h.emitDisconnect(c) h.removeClient(c.ID) c.Close() } // CloseAll closes every connected client. Used during graceful shutdown so peers // receive a clean close frame before the process exits. func (h *Hub) CloseAll() { for _, c := range h.Clients() { c.Close() } }