package main

import (
	"encoding/json"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
)

// Client is a minimal WSTS speaker used by the load tester. It mirrors the
// frontend SDK's framing: requests carry a numeric id and a trailing "R", and the
// engine replies with [payload, id, "E"]; server-initiated messages arrive as
// [payload, name].
type Client struct {
	// counter is the last request id handed out. It is updated with
	// atomic.AddInt64, which on 32-bit platforms (386, ARM) requires 64-bit
	// alignment; only the first word of an allocated struct is guaranteed to
	// be aligned, so this must remain the first field.
	counter int64

	// ID is the socket id assigned by the server's "id" signal. It is set at
	// most once, from the read goroutine; read it only after WaitID returns nil.
	ID string

	conn *websocket.Conn

	mu      sync.Mutex // guards pending
	pending map[int64]chan json.RawMessage

	sigMu   sync.Mutex // guards signals
	signals map[string]func(json.RawMessage)

	writeMu sync.Mutex // gorilla allows only one concurrent writer
	closed  atomic.Bool

	idOnce  sync.Once     // makes setting ID / closing idReady happen at most once
	idReady chan struct{} // closed once ID has been set
}

// Dial connects to the engine at url and starts the read loop on a background
// goroutine. The socket id arrives asynchronously via the "id" signal; call
// WaitID to block until it is known.
func Dial(url string) (*Client, error) {
	conn, _, err := websocket.DefaultDialer.Dial(url, nil)
	if err != nil {
		return nil, fmt.Errorf("dialing %s: %w", url, err)
	}
	c := &Client{
		conn:    conn,
		pending: make(map[int64]chan json.RawMessage),
		signals: make(map[string]func(json.RawMessage)),
		idReady: make(chan struct{}),
	}
	// Capture our socket id from the "id" (YourID) signal. Only the first
	// non-empty id is recorded: a repeated "id" frame must neither close
	// idReady a second time (which panics) nor rewrite ID while other
	// goroutines may be reading it.
	c.OnSignal("id", func(raw json.RawMessage) {
		var p struct {
			Value string `json:"value"`
		}
		if json.Unmarshal(raw, &p) != nil || p.Value == "" {
			return
		}
		c.idOnce.Do(func() {
			c.ID = p.Value
			close(c.idReady)
		})
	})
	go c.readLoop()
	return c, nil
}

// WaitID blocks until the server has told us our socket id (or the timeout
// fires). After it returns nil, c.ID is safe to read.
func (c *Client) WaitID(timeout time.Duration) error {
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-c.idReady:
		return nil
	case <-timer.C:
		return fmt.Errorf("timed out waiting for socket id")
	}
}

// OnSignal registers fn as the handler for the server-initiated signal name,
// replacing any handler previously registered under that name. Handlers are
// invoked synchronously on the read goroutine, so a slow handler delays every
// subsequent frame. Note that registering "id" replaces the handler Dial
// installs, after which WaitID can no longer succeed.
func (c *Client) OnSignal(name string, fn func(json.RawMessage)) {
	c.sigMu.Lock()
	defer c.sigMu.Unlock()
	c.signals[name] = fn
}

// Request sends a request and waits for the engine's reply, returning the round
// trip time. The payload must be a JSON object carrying a "type".
func (c *Client) Request(payload any, timeout time.Duration) (json.RawMessage, time.Duration, error) { id := atomic.AddInt64(&c.counter, 1) ch := make(chan json.RawMessage, 1) c.mu.Lock() c.pending[id] = ch c.mu.Unlock() defer func() { c.mu.Lock() delete(c.pending, id) c.mu.Unlock() }() start := time.Now() if err := c.write([]any{payload, id, "R"}); err != nil { return nil, 0, err } select { case resp := <-ch: return resp, time.Since(start), nil case <-time.After(timeout): return nil, 0, fmt.Errorf("request timed out") } } // SendOnly sends a fire-and-forget message (the "R" string id path) for which the // engine produces no reply. func (c *Client) SendOnly(payload any) error { return c.write([]any{payload, "R"}) } func (c *Client) write(v any) error { b, err := json.Marshal(v) if err != nil { return err } c.writeMu.Lock() defer c.writeMu.Unlock() return c.conn.WriteMessage(websocket.TextMessage, b) } // Close shuts the connection down. func (c *Client) Close() { if c.closed.Swap(true) { return } _ = c.conn.Close() } func (c *Client) readLoop() { for { _, data, err := c.conn.ReadMessage() if err != nil { return } var arr []json.RawMessage if json.Unmarshal(data, &arr) != nil || len(arr) < 2 { continue } // arr[1] is either a numeric id (a reply) or a string (a signal). var num int64 if json.Unmarshal(arr[1], &num) == nil && looksNumeric(arr[1]) { c.deliverReply(num, arr[0]) continue } var name string if json.Unmarshal(arr[1], &name) == nil { c.deliverSignal(name, arr[0]) } } } func (c *Client) deliverReply(id int64, payload json.RawMessage) { c.mu.Lock() ch := c.pending[id] c.mu.Unlock() if ch != nil { select { case ch <- payload: default: } } } func (c *Client) deliverSignal(name string, payload json.RawMessage) { c.sigMu.Lock() fn := c.signals[name] c.sigMu.Unlock() if fn != nil { fn(payload) } } // looksNumeric reports whether a raw JSON token is a number (so "5" is a reply id // but "\"room/joined\"" is a signal name). 
func looksNumeric(raw json.RawMessage) bool {
	// In JSON only a number can begin with '-' or a decimal digit; strings
	// begin with '"', so a quoted "5" is rejected here.
	if len(raw) > 0 {
		switch b := raw[0]; {
		case b == '-', '0' <= b && b <= '9':
			return true
		}
	}
	return false
}