MWSE/loadtest/client.go

183 lines
4.1 KiB
Go

package main
import (
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
// Client is a minimal WSTS speaker used by the load tester. It mirrors the
// frontend SDK's framing: requests carry a numeric id and a trailing "R", and the
// engine replies with [payload, id, "E"]; server-initiated messages arrive as
// [payload, name].
type Client struct {
	// counter is the source of request ids and is updated with
	// atomic.AddInt64. It must remain the FIRST field: on 386, ARM and 32-bit
	// MIPS the 64-bit atomic functions require 8-byte alignment, which Go only
	// guarantees for the first word of an allocated struct. (After a string and
	// a pointer it would sit at offset 12 on those platforms and panic.)
	counter int64

	ID   string // assigned by the server's "id" signal
	conn *websocket.Conn

	mu      sync.Mutex // guards pending
	pending map[int64]chan json.RawMessage

	sigMu   sync.Mutex // guards signals
	signals map[string]func(json.RawMessage)

	writeMu sync.Mutex    // gorilla allows only one concurrent writer
	closed  atomic.Bool   // set by the first Close call
	idReady chan struct{} // closed once ID has been set
}
// Dial connects to the engine at url and starts the read loop on a new
// goroutine, which runs until a read on the connection fails (for example after
// Close). The client records its socket id from the first "id" signal the
// server sends; use WaitID to block until it has arrived.
func Dial(url string) (*Client, error) {
	conn, _, err := websocket.DefaultDialer.Dial(url, nil)
	if err != nil {
		return nil, fmt.Errorf("dialing %q: %w", url, err)
	}
	c := &Client{
		conn:    conn,
		pending: make(map[int64]chan json.RawMessage),
		signals: make(map[string]func(json.RawMessage)),
		idReady: make(chan struct{}),
	}
	// Capture our socket id from the YourID signal. Only the first id is
	// accepted: closing idReady a second time would panic inside the read loop,
	// and rewriting c.ID after WaitID has returned would race with readers of
	// ID. The handler is invoked via deliverSignal from readLoop, a single
	// goroutine, so this check-then-close does not race with itself.
	c.OnSignal("id", func(raw json.RawMessage) {
		select {
		case <-c.idReady:
			return // id already recorded; ignore repeats
		default:
		}
		var p struct {
			Value string `json:"value"`
		}
		if json.Unmarshal(raw, &p) == nil && p.Value != "" {
			c.ID = p.Value
			close(c.idReady)
		}
	})
	go c.readLoop()
	return c, nil
}
// WaitID waits for the server's "id" signal to populate c.ID. It returns nil as
// soon as the id is known, or an error once timeout has elapsed without one.
func (c *Client) WaitID(timeout time.Duration) error {
	deadline := time.After(timeout)
	select {
	case <-deadline:
		return fmt.Errorf("timed out waiting for socket id")
	case <-c.idReady:
		return nil
	}
}
// OnSignal installs fn as the handler for the server-initiated signal called
// name, replacing any handler previously registered under that name. Handlers
// are invoked from the read loop via deliverSignal.
func (c *Client) OnSignal(name string, fn func(json.RawMessage)) {
	c.sigMu.Lock()
	defer c.sigMu.Unlock()
	c.signals[name] = fn
}
// Request sends payload as a request frame ([payload, id, "R"]) and blocks until
// the reply carrying the same id arrives or timeout elapses after the write. It
// returns the raw reply payload and the round-trip time, measured from just
// before the write (so it includes any wait for the write lock). The payload
// must be a JSON object carrying a "type".
func (c *Client) Request(payload any, timeout time.Duration) (json.RawMessage, time.Duration, error) {
	id := atomic.AddInt64(&c.counter, 1)
	// One slot of buffer lets deliverReply's non-blocking send succeed even if
	// the reply lands before we reach the select below.
	ch := make(chan json.RawMessage, 1)
	c.mu.Lock()
	c.pending[id] = ch
	c.mu.Unlock()
	defer func() {
		c.mu.Lock()
		delete(c.pending, id)
		c.mu.Unlock()
	}()
	start := time.Now()
	if err := c.write([]any{payload, id, "R"}); err != nil {
		return nil, 0, fmt.Errorf("sending request %d: %w", id, err)
	}
	// An explicit timer that is stopped on return: with time.After, every
	// request answered early would keep its timer alive until the full timeout
	// elapsed (before Go 1.23), which accumulates at load-test request rates.
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case resp := <-ch:
		return resp, time.Since(start), nil
	case <-timer.C:
		return nil, 0, fmt.Errorf("request timed out")
	}
}
// SendOnly writes payload as a fire-and-forget frame, [payload, "R"]. The frame
// carries no numeric id and no waiter is registered, so any reply could not be
// matched to it. This path is meant for messages the engine does not answer.
func (c *Client) SendOnly(payload any) error {
	frame := []any{payload, "R"}
	return c.write(frame)
}
// write JSON-encodes v and sends it as a single text message. Encoding errors
// are returned before the connection is touched. writeMu serializes the send
// because gorilla/websocket allows only one concurrent writer.
func (c *Client) write(v any) error {
	frame, err := json.Marshal(v)
	if err == nil {
		c.writeMu.Lock()
		defer c.writeMu.Unlock()
		err = c.conn.WriteMessage(websocket.TextMessage, frame)
	}
	return err
}
// Close closes the underlying connection. Only the first call has any effect;
// later calls return immediately.
func (c *Client) Close() {
	if wasClosed := c.closed.Swap(true); !wasClosed {
		// Best-effort shutdown: nothing useful can be done with a close error.
		_ = c.conn.Close()
	}
}
// readLoop runs on the goroutine started by Dial. It reads frames until the
// connection fails and routes them by their second element: a number is a
// reply id (handed to deliverReply), and a string is a signal name (handed to
// deliverSignal). Frames that are not JSON arrays of at least two elements are
// ignored.
func (c *Client) readLoop() {
	// gorilla/websocket read errors are permanent, so the connection is
	// unusable once we return. Close it here so the socket is released even if
	// the caller never calls Close. Close is idempotent, so a later explicit
	// Close is still safe.
	defer c.Close()
	for {
		_, data, err := c.conn.ReadMessage()
		if err != nil {
			return
		}
		var arr []json.RawMessage
		if json.Unmarshal(data, &arr) != nil || len(arr) < 2 {
			continue // not a WSTS frame
		}
		// arr[1] is either a numeric id (a reply) or a string (a signal).
		// looksNumeric is needed in addition to Unmarshal because decoding
		// JSON null into an int64 succeeds and leaves the value 0.
		var num int64
		if looksNumeric(arr[1]) && json.Unmarshal(arr[1], &num) == nil {
			c.deliverReply(num, arr[0])
			continue
		}
		var name string
		if json.Unmarshal(arr[1], &name) == nil {
			c.deliverSignal(name, arr[0])
		}
	}
}
// deliverReply hands payload to the Request waiting on id, if any. A reply is
// dropped if its id is no longer pending (Request removes its entry when it
// returns, including on timeout). It is also dropped if the waiter's one-slot
// buffer is already full. Either way the read loop never blocks here.
func (c *Client) deliverReply(id int64, payload json.RawMessage) {
	c.mu.Lock()
	waiter := c.pending[id]
	c.mu.Unlock()
	if waiter == nil {
		return
	}
	select {
	case waiter <- payload:
	default: // a reply is already buffered; discard this one
	}
}
// deliverSignal calls the handler registered for name, if any, with payload.
// The handler runs after sigMu is released, so a handler may itself call
// OnSignal without deadlocking.
func (c *Client) deliverSignal(name string, payload json.RawMessage) {
	c.sigMu.Lock()
	handler, ok := c.signals[name]
	c.sigMu.Unlock()
	if !ok || handler == nil {
		return
	}
	handler(payload)
}
// looksNumeric reports whether a raw JSON token is a number, judged by its
// first byte: a JSON number always starts with '-' or a digit. So 5 is a reply
// id, but "room/joined" (a JSON string) and null are not. An empty token is
// never numeric.
func looksNumeric(raw json.RawMessage) bool {
	if len(raw) > 0 {
		switch raw[0] {
		case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
			return true
		}
	}
	return false
}