768 lines
14 KiB
Go
768 lines
14 KiB
Go
package melody
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"math/rand"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"testing/quick"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// TestMsg is the fixed payload shared by the tests in this file.
var TestMsg = []byte("test")
|
|
|
|
type TestServer struct {
|
|
withKeys bool
|
|
m *Melody
|
|
}
|
|
|
|
func NewTestServerHandler(handler handleMessageFunc) *TestServer {
|
|
m := New()
|
|
m.HandleMessage(handler)
|
|
return &TestServer{
|
|
m: m,
|
|
}
|
|
}
|
|
|
|
func NewTestServer() *TestServer {
|
|
m := New()
|
|
return &TestServer{
|
|
m: m,
|
|
}
|
|
}
|
|
|
|
func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if s.withKeys {
|
|
s.m.HandleRequestWithKeys(w, r, make(map[string]any))
|
|
} else {
|
|
s.m.HandleRequest(w, r)
|
|
}
|
|
}
|
|
|
|
func NewDialer(url string) (*websocket.Conn, error) {
|
|
dialer := &websocket.Dialer{}
|
|
conn, _, err := dialer.Dial(strings.Replace(url, "http", "ws", 1), nil)
|
|
return conn, err
|
|
}
|
|
|
|
func MustNewDialer(url string) *websocket.Conn {
|
|
conn, err := NewDialer(url)
|
|
|
|
if err != nil {
|
|
panic("could not dail websocket")
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func TestEcho(t *testing.T) {
|
|
ws := NewTestServerHandler(func(session *Session, msg []byte) {
|
|
session.Write(msg)
|
|
})
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
fn := func(msg string) bool {
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
|
|
_, ret, err := conn.ReadMessage()
|
|
|
|
assert.Nil(t, err)
|
|
|
|
assert.Equal(t, msg, string(ret))
|
|
|
|
return true
|
|
}
|
|
|
|
err := quick.Check(fn, nil)
|
|
|
|
assert.Nil(t, err)
|
|
}
|
|
|
|
func TestEchoBinary(t *testing.T) {
|
|
ws := NewTestServerHandler(func(session *Session, msg []byte) {
|
|
session.WriteBinary(msg)
|
|
})
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
fn := func(msg string) bool {
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
|
|
_, ret, err := conn.ReadMessage()
|
|
|
|
assert.Nil(t, err)
|
|
|
|
assert.True(t, bytes.Equal([]byte(msg), ret))
|
|
|
|
return true
|
|
}
|
|
|
|
err := quick.Check(fn, nil)
|
|
|
|
assert.Nil(t, err)
|
|
}
|
|
|
|
func TestWriteClosedServer(t *testing.T) {
|
|
done := make(chan bool)
|
|
|
|
ws := NewTestServer()
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
ws.m.HandleConnect(func(s *Session) {
|
|
s.Close()
|
|
})
|
|
|
|
ws.m.HandleDisconnect(func(s *Session) {
|
|
err := s.Write(TestMsg)
|
|
|
|
assert.NotNil(t, err)
|
|
close(done)
|
|
})
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
conn.ReadMessage()
|
|
defer conn.Close()
|
|
|
|
<-done
|
|
}
|
|
|
|
func TestWriteClosedClient(t *testing.T) {
|
|
done := make(chan bool)
|
|
|
|
ws := NewTestServer()
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
ws.m.HandleDisconnect(func(s *Session) {
|
|
err := s.Write(TestMsg)
|
|
|
|
assert.NotNil(t, err)
|
|
close(done)
|
|
})
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
conn.Close()
|
|
|
|
<-done
|
|
}
|
|
|
|
func TestUpgrader(t *testing.T) {
|
|
ws := NewTestServer()
|
|
ws.m.HandleMessage(func(session *Session, msg []byte) {
|
|
session.Write(msg)
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
ws.m.Upgrader = &websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool { return false },
|
|
}
|
|
|
|
_, err := NewDialer(server.URL)
|
|
|
|
assert.ErrorIs(t, err, websocket.ErrBadHandshake)
|
|
}
|
|
|
|
func TestBroadcast(t *testing.T) {
|
|
n := 10
|
|
msg := "test"
|
|
|
|
test := func(h func(*TestServer), w func(*websocket.Conn)) {
|
|
|
|
ws := NewTestServer()
|
|
|
|
h(ws)
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
listeners := make([]*websocket.Conn, n)
|
|
for i := range listeners {
|
|
listener := MustNewDialer(server.URL)
|
|
listeners[i] = listener
|
|
defer listeners[i].Close()
|
|
}
|
|
|
|
w(conn)
|
|
|
|
for _, listener := range listeners {
|
|
_, ret, err := listener.ReadMessage()
|
|
|
|
assert.Nil(t, err)
|
|
|
|
assert.Equal(t, msg, string(ret))
|
|
}
|
|
}
|
|
|
|
test(func(ws *TestServer) {
|
|
ws.m.HandleMessage(func(s *Session, msg []byte) {
|
|
ws.m.Broadcast(msg)
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
})
|
|
|
|
test(func(ws *TestServer) {
|
|
ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
|
|
ws.m.BroadcastBinary(msg)
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
|
|
})
|
|
|
|
test(func(ws *TestServer) {
|
|
ws.m.HandleMessage(func(s *Session, msg []byte) {
|
|
ws.m.BroadcastFilter(msg, func(s *Session) bool {
|
|
return true
|
|
})
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
})
|
|
|
|
test(func(ws *TestServer) {
|
|
ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
|
|
ws.m.BroadcastBinaryFilter(msg, func(s *Session) bool {
|
|
return true
|
|
})
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
|
|
})
|
|
|
|
test(func(ws *TestServer) {
|
|
ws.m.HandleMessage(func(s *Session, msg []byte) {
|
|
ws.m.BroadcastOthers(msg, s)
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
})
|
|
|
|
test(func(ws *TestServer) {
|
|
ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
|
|
ws.m.BroadcastBinaryOthers(msg, s)
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
|
|
})
|
|
|
|
test(func(ws *TestServer) {
|
|
ws.m.HandleMessage(func(s *Session, msg []byte) {
|
|
ss, _ := ws.m.Sessions()
|
|
ws.m.BroadcastMultiple(msg, ss)
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
})
|
|
}
|
|
|
|
func TestClose(t *testing.T) {
|
|
ws := NewTestServer()
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
n := 10
|
|
|
|
conns := make([]*websocket.Conn, n)
|
|
for i := range conns {
|
|
conn := MustNewDialer(server.URL)
|
|
conns[i] = conn
|
|
defer conns[i].Close()
|
|
}
|
|
|
|
q := make(chan bool)
|
|
ws.m.HandleDisconnect(func(s *Session) {
|
|
q <- true
|
|
})
|
|
|
|
ws.m.Close()
|
|
|
|
for _, conn := range conns {
|
|
conn.ReadMessage()
|
|
}
|
|
|
|
assert.Zero(t, ws.m.Len())
|
|
|
|
m := 0
|
|
for range q {
|
|
m += 1
|
|
if m == n {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestLen(t *testing.T) {
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
connect := int(rand.Int31n(100))
|
|
disconnect := rand.Float32()
|
|
conns := make([]*websocket.Conn, connect)
|
|
|
|
defer func() {
|
|
for _, conn := range conns {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
}
|
|
}()
|
|
|
|
ws := NewTestServer()
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
disconnected := 0
|
|
for i := 0; i < connect; i++ {
|
|
conn := MustNewDialer(server.URL)
|
|
|
|
if rand.Float32() < disconnect {
|
|
conns[i] = nil
|
|
disconnected++
|
|
conn.Close()
|
|
continue
|
|
}
|
|
|
|
conns[i] = conn
|
|
}
|
|
|
|
time.Sleep(time.Millisecond)
|
|
|
|
connected := connect - disconnected
|
|
|
|
assert.Equal(t, ws.m.Len(), connected)
|
|
}
|
|
|
|
func TestSessions(t *testing.T) {
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
connect := int(rand.Int31n(100))
|
|
disconnect := rand.Float32()
|
|
conns := make([]*websocket.Conn, connect)
|
|
defer func() {
|
|
for _, conn := range conns {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
}
|
|
}()
|
|
|
|
ws := NewTestServer()
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
disconnected := 0
|
|
for i := 0; i < connect; i++ {
|
|
conn, err := NewDialer(server.URL)
|
|
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
if rand.Float32() < disconnect {
|
|
conns[i] = nil
|
|
disconnected++
|
|
conn.Close()
|
|
continue
|
|
}
|
|
|
|
conns[i] = conn
|
|
}
|
|
|
|
time.Sleep(time.Millisecond)
|
|
|
|
connected := connect - disconnected
|
|
|
|
ss, err := ws.m.Sessions()
|
|
|
|
assert.Nil(t, err)
|
|
|
|
assert.Equal(t, len(ss), connected)
|
|
}
|
|
|
|
func TestPingPong(t *testing.T) {
|
|
done := make(chan bool)
|
|
|
|
ws := NewTestServer()
|
|
ws.m.Config.PingPeriod = time.Millisecond
|
|
|
|
ws.m.HandlePong(func(s *Session) {
|
|
close(done)
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
go conn.NextReader()
|
|
|
|
<-done
|
|
}
|
|
|
|
func TestHandleClose(t *testing.T) {
|
|
done := make(chan bool)
|
|
|
|
ws := NewTestServer()
|
|
ws.m.Config.PingPeriod = time.Millisecond
|
|
|
|
ws.m.HandleClose(func(s *Session, code int, text string) error {
|
|
close(done)
|
|
return nil
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
|
|
conn.WriteMessage(websocket.CloseMessage, nil)
|
|
|
|
<-done
|
|
}
|
|
|
|
func TestHandleError(t *testing.T) {
|
|
done := make(chan bool)
|
|
|
|
ws := NewTestServer()
|
|
|
|
ws.m.HandleError(func(s *Session, err error) {
|
|
var closeError *websocket.CloseError
|
|
assert.ErrorAs(t, err, &closeError)
|
|
close(done)
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
|
|
conn.Close()
|
|
|
|
<-done
|
|
}
|
|
|
|
func TestHandleErrorWrite(t *testing.T) {
|
|
writeError := make(chan struct{})
|
|
disconnect := make(chan struct{})
|
|
|
|
ws := NewTestServer()
|
|
ws.m.Config.WriteWait = 0
|
|
|
|
ws.m.HandleConnect(func(s *Session) {
|
|
err := s.Write(TestMsg)
|
|
assert.Nil(t, err)
|
|
})
|
|
|
|
ws.m.HandleError(func(s *Session, err error) {
|
|
assert.NotNil(t, err)
|
|
|
|
if os.IsTimeout(err) {
|
|
close(writeError)
|
|
}
|
|
})
|
|
|
|
ws.m.HandleDisconnect(func(s *Session) {
|
|
close(disconnect)
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
go conn.NextReader()
|
|
|
|
<-writeError
|
|
<-disconnect
|
|
}
|
|
|
|
func TestErrClosed(t *testing.T) {
|
|
res := make(chan *Session)
|
|
|
|
ws := NewTestServer()
|
|
|
|
ws.m.HandleConnect(func(s *Session) {
|
|
ws.m.CloseWithMsg(TestMsg)
|
|
})
|
|
|
|
ws.m.HandleDisconnect(func(s *Session) {
|
|
res <- s
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
go conn.ReadMessage()
|
|
|
|
s := <-res
|
|
|
|
assert.True(t, s.IsClosed())
|
|
assert.True(t, ws.m.IsClosed())
|
|
_, err := ws.m.Sessions()
|
|
assert.ErrorIs(t, err, ErrClosed)
|
|
assert.ErrorIs(t, ws.m.Close(), ErrClosed)
|
|
assert.ErrorIs(t, ws.m.CloseWithMsg(TestMsg), ErrClosed)
|
|
|
|
assert.ErrorIs(t, ws.m.Broadcast(TestMsg), ErrClosed)
|
|
assert.ErrorIs(t, ws.m.BroadcastBinary(TestMsg), ErrClosed)
|
|
assert.ErrorIs(t, ws.m.BroadcastFilter(TestMsg, func(s *Session) bool { return true }), ErrClosed)
|
|
assert.ErrorIs(t, ws.m.BroadcastBinaryFilter(TestMsg, func(s *Session) bool { return true }), ErrClosed)
|
|
assert.ErrorIs(t, ws.m.HandleRequest(nil, nil), ErrClosed)
|
|
}
|
|
|
|
func TestErrSessionClosed(t *testing.T) {
|
|
res := make(chan *Session)
|
|
|
|
ws := NewTestServer()
|
|
|
|
ws.m.HandleConnect(func(s *Session) {
|
|
s.CloseWithMsg(TestMsg)
|
|
})
|
|
|
|
ws.m.HandleDisconnect(func(s *Session) {
|
|
res <- s
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
go conn.ReadMessage()
|
|
|
|
s := <-res
|
|
|
|
assert.True(t, s.IsClosed())
|
|
assert.ErrorIs(t, s.Write(TestMsg), ErrSessionClosed)
|
|
assert.ErrorIs(t, s.WriteBinary(TestMsg), ErrSessionClosed)
|
|
assert.ErrorIs(t, s.CloseWithMsg(TestMsg), ErrSessionClosed)
|
|
assert.ErrorIs(t, s.Close(), ErrSessionClosed)
|
|
assert.ErrorIs(t, ws.m.BroadcastMultiple(TestMsg, []*Session{s}), ErrSessionClosed)
|
|
|
|
assert.ErrorIs(t, s.writeRaw(nil), ErrWriteClosed)
|
|
s.writeMessage(nil)
|
|
}
|
|
|
|
func TestErrMessageBufferFull(t *testing.T) {
|
|
done := make(chan bool)
|
|
|
|
ws := NewTestServerHandler(func(session *Session, msg []byte) {
|
|
session.Write(msg)
|
|
session.Write(msg)
|
|
})
|
|
ws.m.Config.MessageBufferSize = 0
|
|
ws.m.HandleError(func(s *Session, err error) {
|
|
if errors.Is(err, ErrMessageBufferFull) {
|
|
close(done)
|
|
}
|
|
})
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
conn.WriteMessage(websocket.TextMessage, TestMsg)
|
|
|
|
<-done
|
|
}
|
|
|
|
func TestSessionKeys(t *testing.T) {
|
|
ws := NewTestServer()
|
|
|
|
ws.m.HandleConnect(func(session *Session) {
|
|
session.Set("stamp", time.Now().UnixNano())
|
|
})
|
|
ws.m.HandleMessage(func(session *Session, msg []byte) {
|
|
stamp := session.MustGet("stamp").(int64)
|
|
session.Write([]byte(strconv.Itoa(int(stamp))))
|
|
})
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
fn := func(msg string) bool {
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
conn.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
|
|
_, ret, err := conn.ReadMessage()
|
|
|
|
assert.Nil(t, err)
|
|
|
|
stamp, err := strconv.Atoi(string(ret))
|
|
|
|
assert.Nil(t, err)
|
|
|
|
diff := int(time.Now().UnixNano()) - stamp
|
|
|
|
assert.Greater(t, diff, 0)
|
|
|
|
return true
|
|
}
|
|
|
|
assert.Nil(t, quick.Check(fn, nil))
|
|
}
|
|
|
|
func TestSessionKeysConcurrent(t *testing.T) {
|
|
ss := make(chan *Session)
|
|
|
|
ws := NewTestServer()
|
|
|
|
ws.m.HandleConnect(func(s *Session) {
|
|
ss <- s
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
s := <-ss
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
|
|
go func() {
|
|
s.Set("test", TestMsg)
|
|
|
|
v1, exists := s.Get("test")
|
|
|
|
assert.True(t, exists)
|
|
assert.Equal(t, v1, TestMsg)
|
|
|
|
v2 := s.MustGet("test")
|
|
|
|
assert.Equal(t, v1, v2)
|
|
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
|
|
go func() {
|
|
s.UnSet("test")
|
|
|
|
_, exists := s.Get("test")
|
|
|
|
assert.False(t, exists)
|
|
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestMisc(t *testing.T) {
|
|
res := make(chan *Session)
|
|
|
|
ws := NewTestServer()
|
|
|
|
ws.m.HandleConnect(func(s *Session) {
|
|
res <- s
|
|
})
|
|
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
go conn.ReadMessage()
|
|
|
|
s := <-res
|
|
|
|
assert.Contains(t, s.LocalAddr().String(), "127.0.0.1")
|
|
assert.Contains(t, s.RemoteAddr().String(), "127.0.0.1")
|
|
assert.Equal(t, FormatCloseMessage(websocket.CloseMessage, "test"), websocket.FormatCloseMessage(websocket.CloseMessage, "test"))
|
|
assert.Panics(t, func() {
|
|
s.MustGet("test")
|
|
})
|
|
}
|
|
|
|
func TestHandleSentMessage(t *testing.T) {
|
|
test := func(h func(*TestServer, chan bool), w func(*websocket.Conn)) {
|
|
done := make(chan bool)
|
|
|
|
ws := NewTestServer()
|
|
server := httptest.NewServer(ws)
|
|
defer server.Close()
|
|
|
|
h(ws, done)
|
|
|
|
conn := MustNewDialer(server.URL)
|
|
defer conn.Close()
|
|
|
|
w(conn)
|
|
|
|
<-done
|
|
}
|
|
|
|
test(func(ws *TestServer, done chan bool) {
|
|
ws.m.HandleMessage(func(s *Session, msg []byte) {
|
|
s.Write(msg)
|
|
})
|
|
|
|
ws.m.HandleSentMessage(func(s *Session, msg []byte) {
|
|
assert.Equal(t, TestMsg, msg)
|
|
close(done)
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.TextMessage, TestMsg)
|
|
})
|
|
|
|
test(func(ws *TestServer, done chan bool) {
|
|
ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
|
|
s.WriteBinary(msg)
|
|
})
|
|
|
|
ws.m.HandleSentMessageBinary(func(s *Session, msg []byte) {
|
|
assert.Equal(t, TestMsg, msg)
|
|
close(done)
|
|
})
|
|
}, func(conn *websocket.Conn) {
|
|
conn.WriteMessage(websocket.BinaryMessage, TestMsg)
|
|
})
|
|
}
|