MWSE/internal/ws/server.go

187 lines
5.2 KiB
Go

package ws
import (
"log"
"net/http"
"time"
"github.com/gorilla/websocket"
"git.saqut.com/saqut/mwse/internal/protocol"
)
// Transport and heartbeat defaults.
//
// The heartbeat interval (defaultPingPeriod) is 10s, the same as the original
// server. Every ping carries pingPayload, a magic string the SDK echoes back in
// its pong; readLoop rejects pongs carrying anything else. defaultPongWait is
// the read deadline that each valid pong renews, so it must stay longer than
// defaultPingPeriod.
const (
	pingPayload = "saQut"

	defaultPingPeriod = 10 * time.Second
	defaultPongWait   = 60 * time.Second
	defaultWriteWait  = 10 * time.Second

	defaultMaxMessageSize = 4 * 1024 * 1024 // inbound message size cap: 4 MiB
)
// Server upgrades HTTP requests to WebSocket connections and runs each
// connection's lifecycle (see handle). It keeps no per-connection state of its
// own; all client state lives on the Hub.
type Server struct {
	hub        *Hub               // registry and message router that owns all client state
	upgrader   websocket.Upgrader // performs the HTTP -> WebSocket handshake in ServeHTTP
	pingPeriod time.Duration      // interval between heartbeat pings sent by pingLoop
	pongWait   time.Duration      // read deadline, renewed by every pong echoing pingPayload
}
// NewServer builds a Server bound to hub, with 4096-byte read/write buffers and
// the package's default heartbeat timings (defaultPingPeriod, defaultPongWait).
//
// CheckOrigin accepts every request, preserving the original server's
// autoAcceptConnections behaviour; origin policy is treated as a deployment
// concern handled in front of the engine.
// NOTE(review): with no origin check any web page can open a socket here —
// confirm the fronting proxy really filters Origin before exposing this.
func NewServer(hub *Hub) *Server {
	allowAnyOrigin := func(*http.Request) bool { return true }

	upgrader := websocket.Upgrader{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
		CheckOrigin:     allowAnyOrigin,
	}

	srv := &Server{
		hub:        hub,
		upgrader:   upgrader,
		pongWait:   defaultPongWait,
		pingPeriod: defaultPingPeriod,
	}
	return srv
}
// ServeHTTP implements http.Handler and is the WebSocket endpoint. It performs
// the upgrade handshake and, on success, runs the connection's whole lifecycle
// on the calling goroutine via handle. On failure it only logs, because
// gorilla/websocket's Upgrade has already replied to the client with an HTTP
// error response.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	wsConn, upErr := s.upgrader.Upgrade(w, r, nil)
	if upErr == nil {
		s.handle(wsConn)
		return
	}
	log.Printf("ws: upgrade failed: %v", upErr)
}
// handle drives one connection from accept to disconnect. It accepts the Conn
// interface rather than a concrete socket so tests can feed it a scripted
// in-memory connection.
//
// Lifecycle:
//  1. Hub.Connect registers the client, starts its writer and fires the connect
//     listeners (id, private room, session defaults). This happens before the
//     read loop so the client can already receive server-initiated messages.
//  2. The heartbeat runs in its own goroutine and stops once the client closes.
//  3. The read loop blocks until the socket errors or closes.
//
// Hub.Disconnect is deferred, so it runs exactly once when the read loop ends,
// even if the loop or a message handler panics. Without the defer, a panic would
// skip Disconnect. net/http recovers handler panics but does not close hijacked
// connections, so the client would stay registered in the hub with its writer
// and heartbeat still alive.
func (s *Server) handle(conn Conn) {
	client := NewClient(conn, newUUID())

	s.hub.Connect(client)
	defer s.hub.Disconnect(client)

	go s.pingLoop(client)
	s.readLoop(client)
}
// readLoop consumes inbound frames until the connection fails or closes.
//
// It caps inbound messages at defaultMaxMessageSize and arms a read deadline of
// s.pongWait. Only a pong that echoes pingPayload pushes that deadline forward,
// so a peer that stops answering pings eventually times out. A pong with any
// other payload makes the pong handler return errBadPong. That fails the read
// and drops the connection, matching the original server.
//
// Non-text frames are ignored and text frames are handed to dispatch. When the
// loop ends it logs protocol violations (bad pong) and unexpected close codes,
// so they do not vanish silently. Ordinary closes and dropped transports stay
// quiet to avoid log noise.
func (s *Server) readLoop(c *Client) {
	c.conn.SetReadLimit(defaultMaxMessageSize)
	_ = c.conn.SetReadDeadline(time.Now().Add(s.pongWait)) // best effort, as before
	c.conn.SetPongHandler(func(appData string) error {
		if appData != pingPayload {
			return errBadPong
		}
		return c.conn.SetReadDeadline(time.Now().Add(s.pongWait))
	})
	for {
		msgType, data, err := c.conn.ReadMessage()
		if err != nil {
			// gorilla/websocket returns the pong handler's error unwrapped, hence
			// the direct comparison. IsUnexpectedCloseError is false for
			// non-close errors (timeouts, EOF), so those are not logged.
			if err == errBadPong || websocket.IsUnexpectedCloseError(err,
				websocket.CloseNormalClosure,
				websocket.CloseGoingAway,
				websocket.CloseNoStatusReceived,
				websocket.CloseAbnormalClosure,
			) {
				log.Printf("ws: read from %s ended: %v", c.ID, err)
			}
			return
		}
		if msgType != websocket.TextMessage {
			continue // the protocol is JSON text; ignore binary frames
		}
		s.dispatch(c, data)
	}
}
// pongError is the error type produced when a pong does not echo pingPayload.
type pongError struct{}

// Error implements the error interface with a fixed message.
func (*pongError) Error() string {
	const msg = "ws: pong validation failed"
	return msg
}

// errBadPong is the shared pongError value that readLoop's pong handler
// returns to force a disconnect.
var errBadPong = &pongError{}
// dispatch decodes one inbound text frame and routes it through the hub.
//
// A frame that fails protocol.Decode is logged as a message error (the Node
// server emitted a 'messageError' event here) and dropped. The connection stays
// open. Otherwise the decoded message goes to Hub.Handle:
//   - If the envelope asks for a reply (WantsReply), the result is wrapped with
//     protocol.Reply using the envelope's ID and flag and queued via c.Send.
//   - Otherwise, a broadcast envelope whose result is a map with a "broadcast"
//     key is only logged. The original inspected results for that directive,
//     but no listener is registered for it, so the hook exists for parity.
func (s *Server) dispatch(c *Client, data []byte) {
	env, decodeErr := protocol.Decode(data)
	if decodeErr != nil {
		log.Printf("ws: message error from %s: %v", c.ID, decodeErr)
		return
	}

	outcome := s.hub.Handle(c, env.Message)

	replyFlag, wantsReply := env.WantsReply()
	if wantsReply {
		c.Send(protocol.Reply(outcome, env.ID, replyFlag))
		return
	}

	if !env.IsBroadcast() {
		return
	}
	directive, isMap := outcome.(map[string]any)
	if !isMap {
		return
	}
	if _, found := directive["broadcast"]; found {
		log.Printf("ws: broadcast directive from %s (no listener registered)", c.ID)
	}
}
// pingLoop is the per-client heartbeat. Every s.pingPeriod it writes a ping
// control frame carrying pingPayload, with a write deadline of
// defaultWriteWait. If a ping cannot be written, the client is closed and the
// loop exits. It also exits as soon as the client's Done channel fires.
//
// gorilla/websocket documents WriteControl as safe to call concurrently with
// the other write methods, which is what lets this run alongside the client's
// writer goroutine (assuming c.conn is backed by a *websocket.Conn).
func (s *Server) pingLoop(c *Client) {
	heartbeat := time.NewTicker(s.pingPeriod)
	defer heartbeat.Stop()

	for {
		select {
		case <-c.Done():
			return
		case <-heartbeat.C:
		}

		deadline := time.Now().Add(defaultWriteWait)
		if werr := c.conn.WriteControl(websocket.PingMessage, []byte(pingPayload), deadline); werr != nil {
			c.Close()
			return
		}
	}
}
// ---- lifecycle helpers (shared by the server and by tests) ---------------
// Connect registers a client with the hub, starts its writer goroutine
// (writePump), and then fires the connect listeners, in that order. The writer
// is already running when the listeners run, presumably so that anything they
// send to the client gets delivered. Exposed so tests and tools can drive the
// same path the server uses.
func (h *Hub) Connect(c *Client) {
	h.addClient(c)
	go c.writePump()
	h.emitConnect(c)
}
// Disconnect tears a client down. It fires the disconnect listeners first, so
// they still see the client registered. It then removes the client from the
// hub by ID and finally closes its transport. It is meant to be called once per
// client: nothing here guards against a second call, which would re-run
// emitDisconnect.
func (h *Hub) Disconnect(c *Client) {
	h.emitDisconnect(c)
	h.removeClient(c.ID)
	c.Close()
}
// CloseAll closes every client currently returned by Clients. It is used during
// graceful shutdown so peers are closed before the process exits. Each closed
// client's read loop then ends and runs the normal Disconnect path.
// NOTE(review): whether peers receive a clean close frame depends on
// Client.Close, which is not visible here.
func (h *Hub) CloseAll() {
	connected := h.Clients()
	for _, client := range connected {
		client.Close()
	}
}