// service/library/melody/melody_test.go (768 lines, 14 KiB, Go)

package melody
import (
"bytes"
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"sync"
"testing"
"testing/quick"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)
// TestMsg is the fixed payload the tests in this file send, write, or
// broadcast whenever any non-empty message will do.
var TestMsg = []byte("test")
// TestServer adapts a *Melody instance to http.Handler so it can be mounted on
// an httptest.Server (see ServeHTTP).
type TestServer struct {
	// withKeys makes ServeHTTP use HandleRequestWithKeys with an empty key
	// map instead of HandleRequest.
	withKeys bool
	// m is the Melody instance under test; tests register handlers and tweak
	// its Config through this field.
	m *Melody
}
// NewTestServerHandler returns a TestServer around a fresh Melody instance
// that has handler registered through HandleMessage.
func NewTestServerHandler(handler handleMessageFunc) *TestServer {
	srv := &TestServer{m: New()}
	srv.m.HandleMessage(handler)
	return srv
}
// NewTestServer returns a TestServer around a fresh Melody instance without
// registering any handlers itself; tests attach them through the m field.
func NewTestServer() *TestServer {
	return &TestServer{m: New()}
}
// ServeHTTP implements http.Handler by passing the request to Melody. When
// withKeys is set it uses HandleRequestWithKeys with an empty key map;
// otherwise it uses HandleRequest.
func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if !s.withKeys {
		s.m.HandleRequest(w, r)
		return
	}
	keys := make(map[string]any)
	s.m.HandleRequestWithKeys(w, r, keys)
}
// NewDialer opens a websocket client connection to url. It first replaces the
// first occurrence of "http" with "ws", so an httptest URL such as
// "http://127.0.0.1:port" becomes "ws://127.0.0.1:port". The handshake
// response is discarded; the returned conn is nil when err is non-nil.
func NewDialer(url string) (*websocket.Conn, error) {
	wsURL := strings.Replace(url, "http", "ws", 1)
	var d websocket.Dialer
	conn, _, err := d.Dial(wsURL, nil)
	return conn, err
}
// MustNewDialer is like NewDialer but panics if the connection cannot be
// established. The panic message includes the underlying dial error so a
// failing test shows why the handshake failed.
func MustNewDialer(url string) *websocket.Conn {
	conn, err := NewDialer(url)
	if err != nil {
		panic("could not dial websocket: " + err.Error())
	}
	return conn
}
// TestEcho is a property test. For arbitrary strings generated by
// testing/quick, a text message sent by a fresh client must come back
// unchanged from a handler that writes every received message to its session.
func TestEcho(t *testing.T) {
	ws := NewTestServerHandler(func(session *Session, msg []byte) {
		session.Write(msg)
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	fn := func(msg string) bool {
		conn := MustNewDialer(server.URL)
		defer conn.Close()
		if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
			t.Errorf("write %q: %v", msg, err)
			return false
		}
		_, ret, err := conn.ReadMessage()
		assert.Nil(t, err)
		assert.Equal(t, msg, string(ret))
		// Also report failure to quick.Check so its error names the input.
		return err == nil && msg == string(ret)
	}
	err := quick.Check(fn, nil)
	assert.Nil(t, err)
}
// TestEchoBinary is a property test. A handler registered with HandleMessage
// answers every message with WriteBinary, so for arbitrary generated strings
// the client must receive the same bytes back, framed as a binary message.
func TestEchoBinary(t *testing.T) {
	ws := NewTestServerHandler(func(session *Session, msg []byte) {
		session.WriteBinary(msg)
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	fn := func(msg string) bool {
		conn := MustNewDialer(server.URL)
		defer conn.Close()
		if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
			t.Errorf("write %q: %v", msg, err)
			return false
		}
		typ, ret, err := conn.ReadMessage()
		assert.Nil(t, err)
		// WriteBinary must produce a binary frame, not merely the same bytes.
		assert.Equal(t, websocket.BinaryMessage, typ)
		assert.True(t, bytes.Equal([]byte(msg), ret))
		return err == nil && typ == websocket.BinaryMessage && bytes.Equal([]byte(msg), ret)
	}
	err := quick.Check(fn, nil)
	assert.Nil(t, err)
}
// TestWriteClosedServer checks that when the server closes a session from its
// connect handler, a Write issued later from the disconnect handler returns an
// error.
func TestWriteClosedServer(t *testing.T) {
	finished := make(chan bool)

	ws := NewTestServer()
	server := httptest.NewServer(ws)
	defer server.Close()

	ws.m.HandleConnect(func(s *Session) {
		s.Close()
	})
	ws.m.HandleDisconnect(func(s *Session) {
		writeErr := s.Write(TestMsg)
		assert.NotNil(t, writeErr)
		close(finished)
	})

	client := MustNewDialer(server.URL)
	defer client.Close()
	// Blocks until the server-side close reaches the client.
	client.ReadMessage()
	<-finished
}
// TestWriteClosedClient checks that after the client closes its connection, a
// Write issued from the server's disconnect handler returns an error.
func TestWriteClosedClient(t *testing.T) {
	ws := NewTestServer()
	server := httptest.NewServer(ws)
	defer server.Close()

	finished := make(chan bool)
	ws.m.HandleDisconnect(func(s *Session) {
		writeErr := s.Write(TestMsg)
		assert.NotNil(t, writeErr)
		close(finished)
	})

	MustNewDialer(server.URL).Close()
	<-finished
}
// TestUpgrader checks that a custom Upgrader installed on the Melody instance
// is honoured. Its CheckOrigin rejects every request, so the client handshake
// must fail with websocket.ErrBadHandshake.
func TestUpgrader(t *testing.T) {
	ws := NewTestServer()
	ws.m.HandleMessage(func(session *Session, msg []byte) {
		session.Write(msg)
	})
	// Install the Upgrader before the server starts accepting connections so
	// the configuration is in place for the very first request.
	ws.m.Upgrader = &websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     func(r *http.Request) bool { return false },
	}
	server := httptest.NewServer(ws)
	defer server.Close()

	conn, err := NewDialer(server.URL)
	if conn != nil {
		// Only reached if the handshake unexpectedly succeeded; don't leak it.
		conn.Close()
	}
	assert.ErrorIs(t, err, websocket.ErrBadHandshake)
}
// TestBroadcast exercises every broadcast variant. One sender client writes
// msg, the server handler rebroadcasts it, and each of n additional listener
// clients must receive msg. Each variant runs as its own subtest so a failure
// names the variant that broke.
func TestBroadcast(t *testing.T) {
	n := 10
	msg := "test"

	// run starts a fresh server configured by setup, connects a sender plus n
	// listeners, lets send write from the sender, then checks that every
	// listener receives msg.
	run := func(t *testing.T, setup func(*TestServer), send func(*websocket.Conn) error) {
		ws := NewTestServer()
		setup(ws)
		server := httptest.NewServer(ws)
		defer server.Close()

		conn := MustNewDialer(server.URL)
		defer conn.Close()

		listeners := make([]*websocket.Conn, n)
		for i := range listeners {
			listeners[i] = MustNewDialer(server.URL)
			defer listeners[i].Close()
		}

		assert.Nil(t, send(conn))
		// A missed broadcast should fail the subtest instead of blocking
		// until the global `go test` timeout.
		deadline := time.Now().Add(5 * time.Second)
		for _, listener := range listeners {
			assert.Nil(t, listener.SetReadDeadline(deadline))
			_, ret, err := listener.ReadMessage()
			assert.Nil(t, err)
			assert.Equal(t, msg, string(ret))
		}
	}

	sendText := func(conn *websocket.Conn) error {
		return conn.WriteMessage(websocket.TextMessage, []byte(msg))
	}
	sendBinary := func(conn *websocket.Conn) error {
		return conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
	}

	cases := []struct {
		name  string
		setup func(*TestServer)
		send  func(*websocket.Conn) error
	}{
		{"Broadcast", func(ws *TestServer) {
			ws.m.HandleMessage(func(s *Session, msg []byte) {
				ws.m.Broadcast(msg)
			})
		}, sendText},
		{"BroadcastBinary", func(ws *TestServer) {
			ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
				ws.m.BroadcastBinary(msg)
			})
		}, sendBinary},
		{"BroadcastFilter", func(ws *TestServer) {
			ws.m.HandleMessage(func(s *Session, msg []byte) {
				ws.m.BroadcastFilter(msg, func(s *Session) bool {
					return true
				})
			})
		}, sendText},
		{"BroadcastBinaryFilter", func(ws *TestServer) {
			ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
				ws.m.BroadcastBinaryFilter(msg, func(s *Session) bool {
					return true
				})
			})
		}, sendBinary},
		{"BroadcastOthers", func(ws *TestServer) {
			ws.m.HandleMessage(func(s *Session, msg []byte) {
				ws.m.BroadcastOthers(msg, s)
			})
		}, sendText},
		{"BroadcastBinaryOthers", func(ws *TestServer) {
			ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
				ws.m.BroadcastBinaryOthers(msg, s)
			})
		}, sendBinary},
		{"BroadcastMultiple", func(ws *TestServer) {
			ws.m.HandleMessage(func(s *Session, msg []byte) {
				ss, _ := ws.m.Sessions()
				ws.m.BroadcastMultiple(msg, ss)
			})
		}, sendText},
	}
	// Subtests run sequentially (no t.Parallel), so capturing tc is safe.
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			run(t, tc.setup, tc.send)
		})
	}
}
// TestClose closes the Melody instance while n clients are connected. It then
// checks that Len reports zero sessions and waits for n invocations of the
// disconnect handler.
func TestClose(t *testing.T) {
	ws := NewTestServer()
	server := httptest.NewServer(ws)
	defer server.Close()

	const n = 10
	conns := make([]*websocket.Conn, n)
	for i := range conns {
		conns[i] = MustNewDialer(server.URL)
		defer conns[i].Close()
	}

	disconnected := make(chan bool)
	ws.m.HandleDisconnect(func(s *Session) {
		disconnected <- true
	})

	ws.m.Close()
	// Each read returns once the server-side close reaches that client.
	for _, c := range conns {
		c.ReadMessage()
	}
	assert.Zero(t, ws.m.Len())

	for i := 0; i < n; i++ {
		<-disconnected
	}
}
// TestLen connects a random number of clients, immediately closes a random
// subset of them, and checks that Len settles on the number still connected.
func TestLen(t *testing.T) {
	// A local source avoids the deprecated global rand.Seed while still
	// varying the scenario between runs.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	connect := int(rng.Int31n(100))
	disconnect := rng.Float32()

	conns := make([]*websocket.Conn, connect)
	defer func() {
		for _, conn := range conns {
			if conn != nil {
				conn.Close()
			}
		}
	}()

	ws := NewTestServer()
	server := httptest.NewServer(ws)
	defer server.Close()

	disconnected := 0
	for i := 0; i < connect; i++ {
		conn := MustNewDialer(server.URL)
		if rng.Float32() < disconnect {
			// conns[i] stays nil so the deferred cleanup skips it.
			disconnected++
			conn.Close()
			continue
		}
		conns[i] = conn
	}

	connected := connect - disconnected
	// Session (un)registration on the server is not synchronous with the
	// client's dial/close, so poll instead of relying on a fixed short sleep.
	assert.Eventually(t, func() bool {
		return ws.m.Len() == connected
	}, time.Second, time.Millisecond)
	assert.Equal(t, connected, ws.m.Len())
}
// TestSessions connects a random number of clients, immediately closes a
// random subset of them, and checks that Sessions settles on exactly the
// clients still connected.
func TestSessions(t *testing.T) {
	// A local source avoids the deprecated global rand.Seed while still
	// varying the scenario between runs.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	connect := int(rng.Int31n(100))
	disconnect := rng.Float32()

	conns := make([]*websocket.Conn, connect)
	defer func() {
		for _, conn := range conns {
			if conn != nil {
				conn.Close()
			}
		}
	}()

	ws := NewTestServer()
	server := httptest.NewServer(ws)
	defer server.Close()

	disconnected := 0
	for i := 0; i < connect; i++ {
		conn, err := NewDialer(server.URL)
		if err != nil {
			// Fatal: continuing would dereference a nil conn below.
			t.Fatal(err)
		}
		if rng.Float32() < disconnect {
			// conns[i] stays nil so the deferred cleanup skips it.
			disconnected++
			conn.Close()
			continue
		}
		conns[i] = conn
	}

	connected := connect - disconnected
	// Session (un)registration on the server is not synchronous with the
	// client's dial/close, so poll instead of relying on a fixed short sleep.
	assert.Eventually(t, func() bool {
		ss, err := ws.m.Sessions()
		return err == nil && len(ss) == connected
	}, time.Second, time.Millisecond)
	ss, err := ws.m.Sessions()
	assert.Nil(t, err)
	assert.Equal(t, connected, len(ss))
}
// TestPingPong checks that the pong handler runs once the client answers a
// server ping. PingPeriod is set to 1ms so pings start almost immediately.
// Because pings keep coming, the pong handler can fire more than once before
// the test tears down; done is therefore closed through a sync.Once to avoid a
// close-of-closed-channel panic.
func TestPingPong(t *testing.T) {
	done := make(chan struct{})
	var once sync.Once
	ws := NewTestServer()
	ws.m.Config.PingPeriod = time.Millisecond
	ws.m.HandlePong(func(s *Session) {
		once.Do(func() { close(done) })
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	conn := MustNewDialer(server.URL)
	defer conn.Close()
	// The client only answers pings while one of its read methods is running.
	// NextReader keeps processing control frames until a data frame arrives.
	go conn.NextReader()
	<-done
}
// TestHandleClose checks that a close frame sent by the client reaches the
// handler registered with HandleClose.
func TestHandleClose(t *testing.T) {
	done := make(chan bool)
	ws := NewTestServer()
	ws.m.Config.PingPeriod = time.Millisecond
	ws.m.HandleClose(func(s *Session, code int, text string) error {
		close(done)
		return nil
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	conn := MustNewDialer(server.URL)
	// Release the client connection when the test ends.
	defer conn.Close()
	// If the close frame cannot be written, waiting on done would hang forever.
	if err := conn.WriteMessage(websocket.CloseMessage, nil); err != nil {
		t.Fatal(err)
	}
	<-done
}
// TestHandleError checks that when the client closes its connection, the
// error handler receives an error that unwraps to *websocket.CloseError.
func TestHandleError(t *testing.T) {
	ws := NewTestServer()
	reported := make(chan bool)
	ws.m.HandleError(func(s *Session, err error) {
		var ce *websocket.CloseError
		assert.ErrorAs(t, err, &ce)
		close(reported)
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	MustNewDialer(server.URL).Close()
	<-reported
}
// TestHandleErrorWrite exercises the write-error path. WriteWait is set to 0,
// presumably making every write deadline already expired, so the message
// queued from the connect handler should surface as a timeout error in the
// error handler. The test then waits for the session to disconnect.
func TestHandleErrorWrite(t *testing.T) {
	timedOut := make(chan struct{})
	disconnected := make(chan struct{})

	ws := NewTestServer()
	ws.m.Config.WriteWait = 0
	ws.m.HandleConnect(func(s *Session) {
		// Write only queues the message, so it succeeds here; the timeout is
		// reported later through the error handler.
		assert.Nil(t, s.Write(TestMsg))
	})
	ws.m.HandleError(func(s *Session, err error) {
		assert.NotNil(t, err)
		if !os.IsTimeout(err) {
			return
		}
		close(timedOut)
	})
	ws.m.HandleDisconnect(func(s *Session) {
		close(disconnected)
	})

	server := httptest.NewServer(ws)
	defer server.Close()
	client := MustNewDialer(server.URL)
	defer client.Close()
	go client.NextReader()

	<-timedOut
	<-disconnected
}
// TestErrClosed checks that once the Melody instance has been closed (here
// from the connect handler), the session is closed and every instance-level
// operation reports ErrClosed.
func TestErrClosed(t *testing.T) {
	sessions := make(chan *Session)

	ws := NewTestServer()
	ws.m.HandleConnect(func(s *Session) {
		ws.m.CloseWithMsg(TestMsg)
	})
	ws.m.HandleDisconnect(func(s *Session) {
		sessions <- s
	})

	server := httptest.NewServer(ws)
	defer server.Close()
	conn := MustNewDialer(server.URL)
	defer conn.Close()
	go conn.ReadMessage()

	s := <-sessions
	assert.True(t, s.IsClosed())
	assert.True(t, ws.m.IsClosed())

	_, sessionsErr := ws.m.Sessions()
	assert.ErrorIs(t, sessionsErr, ErrClosed)

	keepAll := func(*Session) bool { return true }
	// Composite literal elements are evaluated left to right, so the calls
	// run in the order listed.
	checks := []struct {
		name string
		err  error
	}{
		{"Close", ws.m.Close()},
		{"CloseWithMsg", ws.m.CloseWithMsg(TestMsg)},
		{"Broadcast", ws.m.Broadcast(TestMsg)},
		{"BroadcastBinary", ws.m.BroadcastBinary(TestMsg)},
		{"BroadcastFilter", ws.m.BroadcastFilter(TestMsg, keepAll)},
		{"BroadcastBinaryFilter", ws.m.BroadcastBinaryFilter(TestMsg, keepAll)},
		{"HandleRequest", ws.m.HandleRequest(nil, nil)},
	}
	for _, c := range checks {
		assert.ErrorIs(t, c.err, ErrClosed, c.name)
	}
}
// TestErrSessionClosed checks that once a session has closed itself from the
// connect handler, session-level operations report ErrSessionClosed and
// writeRaw reports ErrWriteClosed.
func TestErrSessionClosed(t *testing.T) {
	sessions := make(chan *Session)

	ws := NewTestServer()
	ws.m.HandleConnect(func(s *Session) {
		s.CloseWithMsg(TestMsg)
	})
	ws.m.HandleDisconnect(func(s *Session) {
		sessions <- s
	})

	server := httptest.NewServer(ws)
	defer server.Close()
	conn := MustNewDialer(server.URL)
	defer conn.Close()
	go conn.ReadMessage()

	s := <-sessions
	assert.True(t, s.IsClosed())

	// Composite literal elements are evaluated left to right, so the calls
	// run in the order listed.
	checks := []struct {
		name string
		err  error
	}{
		{"Write", s.Write(TestMsg)},
		{"WriteBinary", s.WriteBinary(TestMsg)},
		{"CloseWithMsg", s.CloseWithMsg(TestMsg)},
		{"Close", s.Close()},
		{"BroadcastMultiple", ws.m.BroadcastMultiple(TestMsg, []*Session{s})},
	}
	for _, c := range checks {
		assert.ErrorIs(t, c.err, ErrSessionClosed, c.name)
	}
	assert.ErrorIs(t, s.writeRaw(nil), ErrWriteClosed)
	// NOTE(review): writeMessage is called but its outcome is not checked;
	// presumably this only verifies it tolerates a nil message on a closed
	// session without panicking.
	s.writeMessage(nil)
}
// TestErrMessageBufferFull checks that with MessageBufferSize set to 0,
// writing twice in quick succession from the message handler surfaces
// ErrMessageBufferFull through the error handler.
func TestErrMessageBufferFull(t *testing.T) {
	done := make(chan bool)
	var once sync.Once
	ws := NewTestServerHandler(func(session *Session, msg []byte) {
		session.Write(msg)
		session.Write(msg)
	})
	ws.m.Config.MessageBufferSize = 0
	ws.m.HandleError(func(s *Session, err error) {
		if errors.Is(err, ErrMessageBufferFull) {
			// Either write may overflow the buffer, so the handler can run
			// more than once; closing done twice would panic.
			once.Do(func() { close(done) })
		}
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	conn := MustNewDialer(server.URL)
	defer conn.Close()
	// If the trigger message cannot be sent, waiting on done would hang.
	if err := conn.WriteMessage(websocket.TextMessage, TestMsg); err != nil {
		t.Fatal(err)
	}
	<-done
}
// TestSessionKeys checks per-session key storage. The connect handler stores
// a nanosecond timestamp under "stamp", the message handler echoes it back,
// and the client checks that the echoed stamp is earlier than the current
// time.
func TestSessionKeys(t *testing.T) {
	ws := NewTestServer()
	ws.m.HandleConnect(func(session *Session) {
		session.Set("stamp", time.Now().UnixNano())
	})
	ws.m.HandleMessage(func(session *Session, msg []byte) {
		stamp := session.MustGet("stamp").(int64)
		// FormatInt keeps the full int64; converting through int would
		// truncate the nanosecond timestamp on 32-bit platforms.
		session.Write([]byte(strconv.FormatInt(stamp, 10)))
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	fn := func(msg string) bool {
		conn := MustNewDialer(server.URL)
		defer conn.Close()
		if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
			t.Errorf("write %q: %v", msg, err)
			return false
		}
		_, ret, err := conn.ReadMessage()
		assert.Nil(t, err)
		stamp, err := strconv.ParseInt(string(ret), 10, 64)
		assert.Nil(t, err)
		diff := time.Now().UnixNano() - stamp
		assert.Greater(t, diff, int64(0))
		return true
	}
	assert.Nil(t, quick.Check(fn, nil))
}
// TestSessionKeysConcurrent hammers one session's key store from 100
// goroutines. The first phase runs Set/Get/MustGet and the second runs
// UnSet/Get. It is presumably meant to be run with -race to catch
// unsynchronized access.
func TestSessionKeysConcurrent(t *testing.T) {
	ss := make(chan *Session)
	ws := NewTestServer()
	ws.m.HandleConnect(func(s *Session) {
		ss <- s
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	conn := MustNewDialer(server.URL)
	defer conn.Close()
	s := <-ss

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.Set("test", TestMsg)
			v1, exists := s.Get("test")
			assert.True(t, exists)
			assert.Equal(t, TestMsg, v1)
			v2 := s.MustGet("test")
			assert.Equal(t, v1, v2)
		}()
	}
	wg.Wait()

	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.UnSet("test")
			_, exists := s.Get("test")
			assert.False(t, exists)
		}()
	}
	wg.Wait()
}
// TestMisc covers small helpers. It checks that session addresses are on the
// loopback interface used by httptest, that FormatCloseMessage matches
// gorilla's encoding, and that MustGet panics for a missing key.
func TestMisc(t *testing.T) {
	res := make(chan *Session)
	ws := NewTestServer()
	ws.m.HandleConnect(func(s *Session) {
		res <- s
	})
	server := httptest.NewServer(ws)
	defer server.Close()
	conn := MustNewDialer(server.URL)
	defer conn.Close()
	go conn.ReadMessage()
	s := <-res
	assert.Contains(t, s.LocalAddr().String(), "127.0.0.1")
	assert.Contains(t, s.RemoteAddr().String(), "127.0.0.1")
	// FormatCloseMessage takes a close status code, so use a real one rather
	// than the CloseMessage frame-type constant.
	assert.Equal(t,
		websocket.FormatCloseMessage(websocket.CloseNormalClosure, "test"),
		FormatCloseMessage(websocket.CloseNormalClosure, "test"))
	assert.Panics(t, func() {
		s.MustGet("test")
	})
}
// TestHandleSentMessage checks that the sent-message hooks fire with the
// payload that was written, for both the text path (Write with
// HandleSentMessage) and the binary path (WriteBinary with
// HandleSentMessageBinary).
func TestHandleSentMessage(t *testing.T) {
	// runCase starts a fresh server and lets configure register handlers that
	// close done once satisfied. It then has send write from a new client and
	// waits for done.
	runCase := func(configure func(*TestServer, chan bool), send func(*websocket.Conn)) {
		done := make(chan bool)
		ws := NewTestServer()
		server := httptest.NewServer(ws)
		defer server.Close()
		configure(ws, done)
		client := MustNewDialer(server.URL)
		defer client.Close()
		send(client)
		<-done
	}

	echoText := func(ws *TestServer, done chan bool) {
		ws.m.HandleMessage(func(s *Session, msg []byte) {
			s.Write(msg)
		})
		ws.m.HandleSentMessage(func(s *Session, msg []byte) {
			assert.Equal(t, TestMsg, msg)
			close(done)
		})
	}
	sendText := func(client *websocket.Conn) {
		client.WriteMessage(websocket.TextMessage, TestMsg)
	}

	echoBinary := func(ws *TestServer, done chan bool) {
		ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
			s.WriteBinary(msg)
		})
		ws.m.HandleSentMessageBinary(func(s *Session, msg []byte) {
			assert.Equal(t, TestMsg, msg)
			close(done)
		})
	}
	sendBinary := func(client *websocket.Conn) {
		client.WriteMessage(websocket.BinaryMessage, TestMsg)
	}

	runCase(echoText, sendText)
	runCase(echoBinary, sendBinary)
}