247 lines
7.0 KiB
Go
247 lines
7.0 KiB
Go
package ws
|
|
|
|
import (
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"git.saqut.com/saqut/mwse/internal/protocol"
|
|
)
|
|
|
|
// Transport and heartbeat defaults. DefaultOptions builds its tuning from the
// duration and size constants below; pingPayload is the magic string carried by
// every server ping, which the SDK must echo back in its pong.
const (
	// defaultWriteWait is the per-write socket deadline.
	defaultWriteWait = 10 * time.Second
	// defaultPongWait is how long the read deadline is extended by a valid pong.
	defaultPongWait = 60 * time.Second
	// defaultPingPeriod matches the original server's 10s ping interval.
	defaultPingPeriod = 10 * time.Second
	// defaultMaxMessageSize is 16 MiB; supports large tunneled payloads (#30).
	defaultMaxMessageSize = 16 * 1024 * 1024

	// pingPayload is sent with each ping; a pong carrying anything else drops
	// the connection.
	pingPayload = "saQut"
)
|
|
|
|
// Options tunes the transport for scale. When passed to NewServer, any field that
// is zero or negative is replaced by the corresponding DefaultOptions value.
//
// PongWait should be comfortably larger than PingInterval: the read deadline is
// only pushed forward when a valid pong arrives, so a PongWait shorter than the
// ping period drops otherwise healthy connections.
type Options struct {
	// OutboundBuffer is the per-connection send queue depth.
	OutboundBuffer int
	// MaxMessageSize is the largest inbound frame accepted (SetReadLimit).
	MaxMessageSize int64
	// ReadBufferSize is gorilla's per-connection read buffer size.
	ReadBufferSize int
	// WriteBufferSize is gorilla's write buffer size; write buffers are pooled.
	WriteBufferSize int
	// PingInterval is the heartbeat period.
	PingInterval time.Duration
	// PongWait is the maximum wait for a pong before the connection is dropped.
	PongWait time.Duration
	// WriteWait is the per-write socket deadline.
	WriteWait time.Duration
}
|
|
|
|
// DefaultOptions returns the built-in tuning.
|
|
func DefaultOptions() Options {
|
|
return Options{
|
|
OutboundBuffer: defaultOutboundBuffer,
|
|
MaxMessageSize: defaultMaxMessageSize,
|
|
ReadBufferSize: 4096,
|
|
WriteBufferSize: 4096,
|
|
PingInterval: defaultPingPeriod,
|
|
PongWait: defaultPongWait,
|
|
WriteWait: defaultWriteWait,
|
|
}
|
|
}
|
|
|
|
func (o Options) withDefaults() Options {
|
|
d := DefaultOptions()
|
|
if o.OutboundBuffer <= 0 {
|
|
o.OutboundBuffer = d.OutboundBuffer
|
|
}
|
|
if o.MaxMessageSize <= 0 {
|
|
o.MaxMessageSize = d.MaxMessageSize
|
|
}
|
|
if o.ReadBufferSize <= 0 {
|
|
o.ReadBufferSize = d.ReadBufferSize
|
|
}
|
|
if o.WriteBufferSize <= 0 {
|
|
o.WriteBufferSize = d.WriteBufferSize
|
|
}
|
|
if o.PingInterval <= 0 {
|
|
o.PingInterval = d.PingInterval
|
|
}
|
|
if o.PongWait <= 0 {
|
|
o.PongWait = d.PongWait
|
|
}
|
|
if o.WriteWait <= 0 {
|
|
o.WriteWait = d.WriteWait
|
|
}
|
|
return o
|
|
}
|
|
|
|
// Server upgrades HTTP requests to WebSocket connections and runs each one's
|
|
// lifecycle. It holds no per-connection state itself; everything lives on the Hub.
|
|
type Server struct {
|
|
hub *Hub
|
|
upgrader websocket.Upgrader
|
|
opts Options
|
|
}
|
|
|
|
// NewServer returns a Server bound to hub. An optional Options tunes the
|
|
// transport; omit it for defaults.
|
|
//
|
|
// The upgrader uses a shared WriteBufferPool so write scratch buffers are reused
|
|
// across connections instead of allocated per connection — a large memory saving
|
|
// at high connection counts. CheckOrigin always returns true to preserve the
|
|
// original autoAcceptConnections behaviour (origin policy belongs in front of the
|
|
// engine).
|
|
func NewServer(hub *Hub, opts ...Options) *Server {
|
|
o := DefaultOptions()
|
|
if len(opts) > 0 {
|
|
o = opts[0].withDefaults()
|
|
}
|
|
return &Server{
|
|
hub: hub,
|
|
opts: o,
|
|
upgrader: websocket.Upgrader{
|
|
ReadBufferSize: o.ReadBufferSize,
|
|
WriteBufferSize: o.WriteBufferSize,
|
|
WriteBufferPool: &sync.Pool{},
|
|
CheckOrigin: func(*http.Request) bool { return true },
|
|
},
|
|
}
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler: it upgrades the request and hands the
|
|
// connection to the lifecycle. It is the WebSocket endpoint.
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := s.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("ws: upgrade failed: %v", err)
|
|
return
|
|
}
|
|
s.handle(conn)
|
|
}
|
|
|
|
// handle drives one connection from accept to disconnect. It is generic over the
|
|
// Conn interface so tests can feed it a scripted in-memory connection.
|
|
func (s *Server) handle(conn Conn) {
|
|
client := newClient(conn, newUUID(), s.opts.OutboundBuffer, s.opts.WriteWait)
|
|
|
|
// Connect: register, start the writer, fire connect listeners (id, private
|
|
// room, session defaults). Must happen before the read loop so the client can
|
|
// already receive server-initiated messages.
|
|
s.hub.Connect(client)
|
|
|
|
// Heartbeat lives in its own goroutine; it stops when the client closes.
|
|
go s.pingLoop(client)
|
|
|
|
// Read loop blocks here until the socket errors or closes.
|
|
s.readLoop(client)
|
|
|
|
// Disconnect runs exactly once, here, when the read loop ends.
|
|
s.hub.Disconnect(client)
|
|
}
|
|
|
|
// readLoop consumes inbound frames. Pongs that do not echo the magic payload drop
|
|
// the connection (matching the original); a valid pong extends the read deadline.
|
|
func (s *Server) readLoop(c *Client) {
|
|
c.conn.SetReadLimit(s.opts.MaxMessageSize)
|
|
_ = c.conn.SetReadDeadline(time.Now().Add(s.opts.PongWait))
|
|
c.conn.SetPongHandler(func(appData string) error {
|
|
if appData != pingPayload {
|
|
return errBadPong
|
|
}
|
|
return c.conn.SetReadDeadline(time.Now().Add(s.opts.PongWait))
|
|
})
|
|
|
|
for {
|
|
msgType, data, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if msgType != websocket.TextMessage {
|
|
continue // the protocol is JSON text; ignore binary frames
|
|
}
|
|
s.dispatch(c, data)
|
|
}
|
|
}
|
|
|
|
// pongError is the error type behind errBadPong.
type pongError struct{}

// Error implements the error interface.
func (*pongError) Error() string { return "ws: pong validation failed" }

// errBadPong is returned from the pong handler when a pong does not echo the
// ping payload. Returning it makes the read fail, which forces a disconnect.
var errBadPong = &pongError{}
|
|
|
|
// dispatch decodes one frame and routes it, then replies according to the WSTS
|
|
// rules. A frame that fails to decode is logged as a message error (the Node
|
|
// server emitted a 'messageError' event here).
|
|
func (s *Server) dispatch(c *Client, data []byte) {
|
|
env, err := protocol.Decode(data)
|
|
if err != nil {
|
|
log.Printf("ws: message error from %s: %v", c.ID, err)
|
|
return
|
|
}
|
|
|
|
result := s.hub.Handle(c, env.Message)
|
|
|
|
if flag, ok := env.WantsReply(); ok {
|
|
c.Send(protocol.Reply(result, env.ID, flag))
|
|
return
|
|
}
|
|
|
|
// "No id" branch: the original inspected the result for a broadcast directive.
|
|
// No service currently emits one, but the hook is preserved for parity.
|
|
if env.IsBroadcast() {
|
|
if m, ok := result.(map[string]any); ok {
|
|
if _, has := m["broadcast"]; has {
|
|
log.Printf("ws: broadcast directive from %s (no listener registered)", c.ID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// pingLoop sends a ping with the magic payload every ping period until the client
|
|
// closes.
|
|
func (s *Server) pingLoop(c *Client) {
|
|
ticker := time.NewTicker(s.opts.PingInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
err := c.conn.WriteControl(
|
|
websocket.PingMessage,
|
|
[]byte(pingPayload),
|
|
time.Now().Add(s.opts.WriteWait),
|
|
)
|
|
if err != nil {
|
|
c.Close()
|
|
return
|
|
}
|
|
case <-c.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// ---- lifecycle helpers (shared by the server and by tests) ---------------
|
|
|
|
// Connect registers a client, starts its writer goroutine, and fires the connect
|
|
// listeners. Exposed so tests and tools can drive the same path the server uses.
|
|
func (h *Hub) Connect(c *Client) {
|
|
h.addClient(c)
|
|
go c.writePump()
|
|
h.emitConnect(c)
|
|
}
|
|
|
|
// Disconnect fires the disconnect listeners, unregisters the client, and tears
|
|
// its transport down. Safe to call once per client.
|
|
func (h *Hub) Disconnect(c *Client) {
|
|
h.emitDisconnect(c)
|
|
h.removeClient(c.ID)
|
|
c.Close()
|
|
}
|
|
|
|
// CloseAll closes every connected client. Used during graceful shutdown so peers
|
|
// receive a clean close frame before the process exits.
|
|
func (h *Hub) CloseAll() {
|
|
for _, c := range h.Clients() {
|
|
c.Close()
|
|
}
|
|
}
|