183 lines
4.1 KiB
Go
183 lines
4.1 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Client is a minimal WSTS speaker used by the load tester. It mirrors the
|
|
// frontend SDK's framing: requests carry a numeric id and a trailing "R", and the
|
|
// engine replies with [payload, id, "E"]; server-initiated messages arrive as
|
|
// [payload, name].
|
|
type Client struct {
|
|
ID string // assigned by the server's "id" signal
|
|
|
|
conn *websocket.Conn
|
|
counter int64
|
|
|
|
mu sync.Mutex
|
|
pending map[int64]chan json.RawMessage
|
|
|
|
sigMu sync.Mutex
|
|
signals map[string]func(json.RawMessage)
|
|
|
|
writeMu sync.Mutex // gorilla allows only one concurrent writer
|
|
|
|
closed atomic.Bool
|
|
idReady chan struct{}
|
|
}
|
|
|
|
// Dial connects to the engine at url and starts the read loop.
|
|
func Dial(url string) (*Client, error) {
|
|
conn, _, err := websocket.DefaultDialer.Dial(url, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c := &Client{
|
|
conn: conn,
|
|
pending: make(map[int64]chan json.RawMessage),
|
|
signals: make(map[string]func(json.RawMessage)),
|
|
idReady: make(chan struct{}),
|
|
}
|
|
// Capture our socket id from the YourID signal.
|
|
c.OnSignal("id", func(raw json.RawMessage) {
|
|
var p struct {
|
|
Value string `json:"value"`
|
|
}
|
|
if json.Unmarshal(raw, &p) == nil && p.Value != "" {
|
|
c.ID = p.Value
|
|
close(c.idReady)
|
|
}
|
|
})
|
|
go c.readLoop()
|
|
return c, nil
|
|
}
|
|
|
|
// WaitID blocks until the server has told us our socket id (or the timeout fires).
|
|
func (c *Client) WaitID(timeout time.Duration) error {
|
|
select {
|
|
case <-c.idReady:
|
|
return nil
|
|
case <-time.After(timeout):
|
|
return fmt.Errorf("timed out waiting for socket id")
|
|
}
|
|
}
|
|
|
|
// OnSignal registers a handler for a server-initiated signal name.
|
|
func (c *Client) OnSignal(name string, fn func(json.RawMessage)) {
|
|
c.sigMu.Lock()
|
|
c.signals[name] = fn
|
|
c.sigMu.Unlock()
|
|
}
|
|
|
|
// Request sends a request and waits for the engine's reply, returning the round
|
|
// trip time. The payload must be a JSON object carrying a "type".
|
|
func (c *Client) Request(payload any, timeout time.Duration) (json.RawMessage, time.Duration, error) {
|
|
id := atomic.AddInt64(&c.counter, 1)
|
|
ch := make(chan json.RawMessage, 1)
|
|
c.mu.Lock()
|
|
c.pending[id] = ch
|
|
c.mu.Unlock()
|
|
|
|
defer func() {
|
|
c.mu.Lock()
|
|
delete(c.pending, id)
|
|
c.mu.Unlock()
|
|
}()
|
|
|
|
start := time.Now()
|
|
if err := c.write([]any{payload, id, "R"}); err != nil {
|
|
return nil, 0, err
|
|
}
|
|
select {
|
|
case resp := <-ch:
|
|
return resp, time.Since(start), nil
|
|
case <-time.After(timeout):
|
|
return nil, 0, fmt.Errorf("request timed out")
|
|
}
|
|
}
|
|
|
|
// SendOnly sends a fire-and-forget message (the "R" string id path) for which the
|
|
// engine produces no reply.
|
|
func (c *Client) SendOnly(payload any) error {
|
|
return c.write([]any{payload, "R"})
|
|
}
|
|
|
|
func (c *Client) write(v any) error {
|
|
b, err := json.Marshal(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
return c.conn.WriteMessage(websocket.TextMessage, b)
|
|
}
|
|
|
|
// Close shuts the connection down.
|
|
func (c *Client) Close() {
|
|
if c.closed.Swap(true) {
|
|
return
|
|
}
|
|
_ = c.conn.Close()
|
|
}
|
|
|
|
func (c *Client) readLoop() {
|
|
for {
|
|
_, data, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
var arr []json.RawMessage
|
|
if json.Unmarshal(data, &arr) != nil || len(arr) < 2 {
|
|
continue
|
|
}
|
|
|
|
// arr[1] is either a numeric id (a reply) or a string (a signal).
|
|
var num int64
|
|
if json.Unmarshal(arr[1], &num) == nil && looksNumeric(arr[1]) {
|
|
c.deliverReply(num, arr[0])
|
|
continue
|
|
}
|
|
var name string
|
|
if json.Unmarshal(arr[1], &name) == nil {
|
|
c.deliverSignal(name, arr[0])
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) deliverReply(id int64, payload json.RawMessage) {
|
|
c.mu.Lock()
|
|
ch := c.pending[id]
|
|
c.mu.Unlock()
|
|
if ch != nil {
|
|
select {
|
|
case ch <- payload:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) deliverSignal(name string, payload json.RawMessage) {
|
|
c.sigMu.Lock()
|
|
fn := c.signals[name]
|
|
c.sigMu.Unlock()
|
|
if fn != nil {
|
|
fn(payload)
|
|
}
|
|
}
|
|
|
|
// looksNumeric reports whether a raw JSON token is a number (so "5" is a reply id
// but "\"room/joined\"" is a signal name). It inspects only the first byte: a
// JSON number always starts with '-' or a decimal digit.
func looksNumeric(raw json.RawMessage) bool {
	return len(raw) > 0 && (raw[0] == '-' || ('0' <= raw[0] && raw[0] <= '9'))
}
|