package melody

import (
	"bytes"
	"errors"
	"math/rand"
	"net/http"
	"net/http/httptest"
	"os"
	"strconv"
	"strings"
	"sync"
	"testing"
	"testing/quick"
	"time"

	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/assert"
)

// TestMsg is the payload used by tests that do not care about message content.
var TestMsg = []byte("test")

// TestServer adapts a *Melody instance to http.Handler for use with httptest.
type TestServer struct {
	// withKeys selects HandleRequestWithKeys (with an empty key map)
	// instead of HandleRequest.
	withKeys bool
	m        *Melody
}

// NewTestServerHandler returns a TestServer whose Melody instance has handler
// registered via HandleMessage.
func NewTestServerHandler(handler handleMessageFunc) *TestServer {
	m := New()
	m.HandleMessage(handler)
	return &TestServer{
		m: m,
	}
}

// NewTestServer returns a TestServer wrapping a fresh Melody instance with no
// handlers registered.
func NewTestServer() *TestServer {
	m := New()
	return &TestServer{
		m: m,
	}
}

// ServeHTTP hands the request to Melody, attaching an empty key map when
// withKeys is set.
func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if s.withKeys {
		s.m.HandleRequestWithKeys(w, r, make(map[string]any))
	} else {
		s.m.HandleRequest(w, r)
	}
}

// NewDialer opens a websocket client connection to url. The first "http" in
// url is replaced with "ws", so "http://" becomes "ws://" and "https://"
// becomes "wss://".
func NewDialer(url string) (*websocket.Conn, error) {
	dialer := &websocket.Dialer{}
	conn, _, err := dialer.Dial(strings.Replace(url, "http", "ws", 1), nil)
	return conn, err
}

// MustNewDialer is like NewDialer but panics if the connection fails.
func MustNewDialer(url string) *websocket.Conn {
	conn, err := NewDialer(url)
	if err != nil {
		// Keep the underlying error so a failing test is diagnosable.
		panic("could not dial websocket: " + err.Error())
	}
	return conn
}

// TestEcho checks that text messages written back with Session.Write arrive
// unchanged at the client.
func TestEcho(t *testing.T) {
	ws := NewTestServerHandler(func(session *Session, msg []byte) {
		session.Write(msg)
	})
	server := httptest.NewServer(ws)
	defer server.Close()

	fn := func(msg string) bool {
		conn := MustNewDialer(server.URL)
		defer conn.Close()

		conn.WriteMessage(websocket.TextMessage, []byte(msg))

		_, ret, err := conn.ReadMessage()

		assert.Nil(t, err)

		assert.Equal(t, msg, string(ret))

		return true
	}

	err := quick.Check(fn, nil)

	assert.Nil(t, err)
}

// TestEchoBinary checks that messages echoed with Session.WriteBinary arrive
// unchanged at the client.
func TestEchoBinary(t *testing.T) {
	ws := NewTestServerHandler(func(session *Session, msg []byte) {
		session.WriteBinary(msg)
	})
	server := httptest.NewServer(ws)
	defer server.Close()

	fn := func(msg string) bool {
		conn := MustNewDialer(server.URL)
		defer conn.Close()

		// NewTestServerHandler registers the handler with HandleMessage, so
		// the client sends a text frame; only the reply is binary.
		conn.WriteMessage(websocket.TextMessage, []byte(msg))

		_, ret, err := conn.ReadMessage()

		assert.Nil(t, err)

		assert.True(t, bytes.Equal([]byte(msg), ret))

		return true
	}

	err := quick.Check(fn, nil)

	assert.Nil(t, err)
}

// TestWriteClosedServer checks that writing to a session the server closed
// itself returns an error.
func TestWriteClosedServer(t *testing.T) {
	done := make(chan bool)

	ws := NewTestServer()

	server := httptest.NewServer(ws)
	defer server.Close()

	ws.m.HandleConnect(func(s *Session) {
		s.Close()
	})

	ws.m.HandleDisconnect(func(s *Session) {
		err := s.Write(TestMsg)

		assert.NotNil(t, err)
		close(done)
	})

	conn := MustNewDialer(server.URL)
	defer conn.Close()
	// Drain the connection; the server closes it from HandleConnect.
	conn.ReadMessage()

	<-done
}

// TestWriteClosedClient checks that writing to a session whose client hung up
// returns an error.
func TestWriteClosedClient(t *testing.T) {
	done := make(chan bool)

	ws := NewTestServer()

	server := httptest.NewServer(ws)
	defer server.Close()

	ws.m.HandleDisconnect(func(s *Session) {
		err := s.Write(TestMsg)

		assert.NotNil(t, err)
		close(done)
	})

	conn := MustNewDialer(server.URL)
	conn.Close()

	<-done
}

// TestUpgrader checks that a custom Upgrader is used, by rejecting every
// origin and expecting the handshake to fail.
func TestUpgrader(t *testing.T) {
	ws := NewTestServer()
	ws.m.HandleMessage(func(session *Session, msg []byte) {
		session.Write(msg)
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	ws.m.Upgrader = &websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     func(r *http.Request) bool { return false },
	}

	_, err := NewDialer(server.URL)

	assert.ErrorIs(t, err, websocket.ErrBadHandshake)
}

// TestBroadcast runs each broadcast variant and checks that every listener
// receives the message sent by a separate client.
func TestBroadcast(t *testing.T) {
	n := 10
	msg := "test"

	// test wires the server with h, sends a message with w, and checks that
	// all n listeners receive msg.
	test := func(h func(*TestServer), w func(*websocket.Conn)) {
		ws := NewTestServer()

		h(ws)

		server := httptest.NewServer(ws)
		defer server.Close()

		conn := MustNewDialer(server.URL)
		defer conn.Close()

		listeners := make([]*websocket.Conn, n)
		for i := range listeners {
			listener := MustNewDialer(server.URL)
			listeners[i] = listener
			defer listener.Close()
		}

		w(conn)

		for _, listener := range listeners {
			_, ret, err := listener.ReadMessage()

			assert.Nil(t, err)

			assert.Equal(t, msg, string(ret))
		}
	}

	test(func(ws *TestServer) {
		ws.m.HandleMessage(func(s *Session, msg []byte) {
			ws.m.Broadcast(msg)
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.TextMessage, []byte(msg))
	})

	test(func(ws *TestServer) {
		ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
			ws.m.BroadcastBinary(msg)
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
	})

	test(func(ws *TestServer) {
		ws.m.HandleMessage(func(s *Session, msg []byte) {
			ws.m.BroadcastFilter(msg, func(s *Session) bool {
				return true
			})
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.TextMessage, []byte(msg))
	})

	test(func(ws *TestServer) {
		ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
			ws.m.BroadcastBinaryFilter(msg, func(s *Session) bool {
				return true
			})
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
	})

	test(func(ws *TestServer) {
		ws.m.HandleMessage(func(s *Session, msg []byte) {
			ws.m.BroadcastOthers(msg, s)
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.TextMessage, []byte(msg))
	})

	test(func(ws *TestServer) {
		ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
			ws.m.BroadcastBinaryOthers(msg, s)
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
	})

	test(func(ws *TestServer) {
		ws.m.HandleMessage(func(s *Session, msg []byte) {
			ss, _ := ws.m.Sessions()
			ws.m.BroadcastMultiple(msg, ss)
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.TextMessage, []byte(msg))
	})
}

// TestClose checks that Melody.Close disconnects every session and fires the
// disconnect handler once per session.
func TestClose(t *testing.T) {
	ws := NewTestServer()

	server := httptest.NewServer(ws)
	defer server.Close()

	n := 10

	conns := make([]*websocket.Conn, n)
	for i := range conns {
		conn := MustNewDialer(server.URL)
		conns[i] = conn
		defer conn.Close()
	}

	q := make(chan bool)
	ws.m.HandleDisconnect(func(s *Session) {
		q <- true
	})

	ws.m.Close()

	for _, conn := range conns {
		conn.ReadMessage()
	}

	assert.Zero(t, ws.m.Len())

	m := 0
	for range q {
		m++
		if m == n {
			break
		}
	}
}

// TestLen opens a random number of connections, closes a random subset, and
// checks that Melody.Len converges to the number still open.
func TestLen(t *testing.T) {
	// A local generator avoids reseeding the global source (rand.Seed is
	// deprecated since Go 1.20).
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))

	connect := int(rng.Int31n(100))
	disconnect := rng.Float32()
	conns := make([]*websocket.Conn, connect)
	defer func() {
		for _, conn := range conns {
			if conn != nil {
				conn.Close()
			}
		}
	}()

	ws := NewTestServer()

	server := httptest.NewServer(ws)
	defer server.Close()

	disconnected := 0
	for i := 0; i < connect; i++ {
		conn := MustNewDialer(server.URL)

		if rng.Float32() < disconnect {
			conns[i] = nil
			disconnected++
			conn.Close()
			continue
		}

		conns[i] = conn
	}

	connected := connect - disconnected

	// Server-side (un)registration may complete after the client-side
	// Dial/Close returns, so poll rather than relying on a fixed sleep.
	assert.Eventually(t, func() bool {
		return ws.m.Len() == connected
	}, time.Second, time.Millisecond, "expected %d connected sessions", connected)
}

// TestSessions is like TestLen but checks the slice returned by
// Melody.Sessions.
func TestSessions(t *testing.T) {
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))

	connect := int(rng.Int31n(100))
	disconnect := rng.Float32()
	conns := make([]*websocket.Conn, connect)
	defer func() {
		for _, conn := range conns {
			if conn != nil {
				conn.Close()
			}
		}
	}()

	ws := NewTestServer()

	server := httptest.NewServer(ws)
	defer server.Close()

	disconnected := 0
	for i := 0; i < connect; i++ {
		conn, err := NewDialer(server.URL)
		if err != nil {
			// Fatal, not Error: continuing would call Close on a nil conn.
			t.Fatal(err)
		}

		if rng.Float32() < disconnect {
			conns[i] = nil
			disconnected++
			conn.Close()
			continue
		}

		conns[i] = conn
	}

	connected := connect - disconnected

	assert.Eventually(t, func() bool {
		ss, err := ws.m.Sessions()
		return err == nil && len(ss) == connected
	}, time.Second, time.Millisecond, "expected %d sessions", connected)

	_, err := ws.m.Sessions()
	assert.Nil(t, err)
}

// TestPingPong checks that the pong handler fires when a client answers the
// server's pings.
func TestPingPong(t *testing.T) {
	done := make(chan bool)

	ws := NewTestServer()
	ws.m.Config.PingPeriod = time.Millisecond

	ws.m.HandlePong(func(s *Session) {
		close(done)
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	// NextReader processes incoming control frames, which makes the client
	// reply to pings.
	go conn.NextReader()

	<-done
}

// TestHandleClose checks that the close handler fires when the client sends a
// close frame.
func TestHandleClose(t *testing.T) {
	done := make(chan bool)

	ws := NewTestServer()
	ws.m.Config.PingPeriod = time.Millisecond

	ws.m.HandleClose(func(s *Session, code int, text string) error {
		close(done)
		return nil
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	conn.WriteMessage(websocket.CloseMessage, nil)

	<-done
}

// TestHandleError checks that an abrupt client disconnect is reported to the
// error handler as a *websocket.CloseError.
func TestHandleError(t *testing.T) {
	done := make(chan bool)

	ws := NewTestServer()

	ws.m.HandleError(func(s *Session, err error) {
		var closeError *websocket.CloseError
		assert.ErrorAs(t, err, &closeError)
		close(done)
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	conn.Close()

	<-done
}

// TestHandleErrorWrite checks that a write deadline of zero surfaces a timeout
// error and then disconnects the session.
func TestHandleErrorWrite(t *testing.T) {
	writeError := make(chan struct{})
	disconnect := make(chan struct{})

	ws := NewTestServer()
	ws.m.Config.WriteWait = 0

	ws.m.HandleConnect(func(s *Session) {
		err := s.Write(TestMsg)
		assert.Nil(t, err)
	})

	ws.m.HandleError(func(s *Session, err error) {
		assert.NotNil(t, err)

		if os.IsTimeout(err) {
			close(writeError)
		}
	})

	ws.m.HandleDisconnect(func(s *Session) {
		close(disconnect)
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	go conn.NextReader()

	<-writeError
	<-disconnect
}

// TestErrClosed checks that every Melody operation returns ErrClosed after
// the instance has been closed.
func TestErrClosed(t *testing.T) {
	res := make(chan *Session)

	ws := NewTestServer()

	ws.m.HandleConnect(func(s *Session) {
		ws.m.CloseWithMsg(TestMsg)
	})

	ws.m.HandleDisconnect(func(s *Session) {
		res <- s
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	go conn.ReadMessage()

	s := <-res

	assert.True(t, s.IsClosed())
	assert.True(t, ws.m.IsClosed())
	_, err := ws.m.Sessions()
	assert.ErrorIs(t, err, ErrClosed)
	assert.ErrorIs(t, ws.m.Close(), ErrClosed)
	assert.ErrorIs(t, ws.m.CloseWithMsg(TestMsg), ErrClosed)

	assert.ErrorIs(t, ws.m.Broadcast(TestMsg), ErrClosed)
	assert.ErrorIs(t, ws.m.BroadcastBinary(TestMsg), ErrClosed)
	assert.ErrorIs(t, ws.m.BroadcastFilter(TestMsg, func(s *Session) bool { return true }), ErrClosed)
	assert.ErrorIs(t, ws.m.BroadcastBinaryFilter(TestMsg, func(s *Session) bool { return true }), ErrClosed)
	assert.ErrorIs(t, ws.m.HandleRequest(nil, nil), ErrClosed)
}

// TestErrSessionClosed checks that every Session operation returns
// ErrSessionClosed (or ErrWriteClosed for writeRaw) once the session is
// closed.
func TestErrSessionClosed(t *testing.T) {
	res := make(chan *Session)

	ws := NewTestServer()

	ws.m.HandleConnect(func(s *Session) {
		s.CloseWithMsg(TestMsg)
	})

	ws.m.HandleDisconnect(func(s *Session) {
		res <- s
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	go conn.ReadMessage()

	s := <-res

	assert.True(t, s.IsClosed())
	assert.ErrorIs(t, s.Write(TestMsg), ErrSessionClosed)
	assert.ErrorIs(t, s.WriteBinary(TestMsg), ErrSessionClosed)
	assert.ErrorIs(t, s.CloseWithMsg(TestMsg), ErrSessionClosed)
	assert.ErrorIs(t, s.Close(), ErrSessionClosed)
	assert.ErrorIs(t, ws.m.BroadcastMultiple(TestMsg, []*Session{s}), ErrSessionClosed)

	assert.ErrorIs(t, s.writeRaw(nil), ErrWriteClosed)
	// Exercises writeMessage on a closed session; only checks it does not panic.
	s.writeMessage(nil)
}

// TestErrMessageBufferFull checks that overflowing a zero-sized outbound
// buffer reports ErrMessageBufferFull to the error handler.
func TestErrMessageBufferFull(t *testing.T) {
	done := make(chan bool)

	ws := NewTestServerHandler(func(session *Session, msg []byte) {
		session.Write(msg)
		session.Write(msg)
	})
	ws.m.Config.MessageBufferSize = 0
	ws.m.HandleError(func(s *Session, err error) {
		if errors.Is(err, ErrMessageBufferFull) {
			close(done)
		}
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	conn.WriteMessage(websocket.TextMessage, TestMsg)

	<-done
}

// TestSessionKeys checks that a value stored with Set in the connect handler
// can be read back with MustGet in the message handler.
func TestSessionKeys(t *testing.T) {
	ws := NewTestServer()

	ws.m.HandleConnect(func(session *Session) {
		session.Set("stamp", time.Now().UnixNano())
	})
	ws.m.HandleMessage(func(session *Session, msg []byte) {
		stamp := session.MustGet("stamp").(int64)
		// Format as int64 so the timestamp is not truncated where int is
		// 32 bits wide.
		session.Write([]byte(strconv.FormatInt(stamp, 10)))
	})
	server := httptest.NewServer(ws)
	defer server.Close()

	fn := func(msg string) bool {
		conn := MustNewDialer(server.URL)
		defer conn.Close()

		conn.WriteMessage(websocket.TextMessage, []byte(msg))

		_, ret, err := conn.ReadMessage()

		assert.Nil(t, err)

		stamp, err := strconv.ParseInt(string(ret), 10, 64)

		assert.Nil(t, err)

		diff := time.Now().UnixNano() - stamp

		assert.Greater(t, diff, int64(0))

		return true
	}

	assert.Nil(t, quick.Check(fn, nil))
}

// TestSessionKeysConcurrent hammers Set/Get/MustGet and then UnSet/Get from
// many goroutines; run with -race to detect unsynchronized key access.
func TestSessionKeysConcurrent(t *testing.T) {
	ss := make(chan *Session)

	ws := NewTestServer()

	ws.m.HandleConnect(func(s *Session) {
		ss <- s
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	s := <-ss

	var wg sync.WaitGroup

	for i := 0; i < 100; i++ {
		wg.Add(1)

		go func() {
			defer wg.Done()

			s.Set("test", TestMsg)

			v1, exists := s.Get("test")

			assert.True(t, exists)
			assert.Equal(t, TestMsg, v1)

			v2 := s.MustGet("test")

			assert.Equal(t, v1, v2)
		}()
	}

	wg.Wait()

	for i := 0; i < 100; i++ {
		wg.Add(1)

		go func() {
			defer wg.Done()

			s.UnSet("test")

			_, exists := s.Get("test")

			assert.False(t, exists)
		}()
	}

	wg.Wait()
}

// TestMisc covers small accessors: local/remote addresses, the
// FormatCloseMessage wrapper, and MustGet's panic on a missing key.
func TestMisc(t *testing.T) {
	res := make(chan *Session)

	ws := NewTestServer()

	ws.m.HandleConnect(func(s *Session) {
		res <- s
	})

	server := httptest.NewServer(ws)
	defer server.Close()

	conn := MustNewDialer(server.URL)
	defer conn.Close()

	go conn.ReadMessage()

	s := <-res

	assert.Contains(t, s.LocalAddr().String(), "127.0.0.1")
	assert.Contains(t, s.RemoteAddr().String(), "127.0.0.1")
	assert.Equal(t, FormatCloseMessage(websocket.CloseMessage, "test"), websocket.FormatCloseMessage(websocket.CloseMessage, "test"))
	assert.Panics(t, func() {
		s.MustGet("test")
	})
}

// TestHandleSentMessage checks that the sent-message handlers fire with the
// payload written by Write and WriteBinary respectively.
func TestHandleSentMessage(t *testing.T) {
	// test wires the server with h (which must close done), sends a message
	// with w, and waits for done.
	test := func(h func(*TestServer, chan bool), w func(*websocket.Conn)) {
		done := make(chan bool)

		ws := NewTestServer()
		server := httptest.NewServer(ws)
		defer server.Close()

		h(ws, done)

		conn := MustNewDialer(server.URL)
		defer conn.Close()

		w(conn)

		<-done
	}

	test(func(ws *TestServer, done chan bool) {
		ws.m.HandleMessage(func(s *Session, msg []byte) {
			s.Write(msg)
		})

		ws.m.HandleSentMessage(func(s *Session, msg []byte) {
			assert.Equal(t, TestMsg, msg)
			close(done)
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.TextMessage, TestMsg)
	})

	test(func(ws *TestServer, done chan bool) {
		ws.m.HandleMessageBinary(func(s *Session, msg []byte) {
			s.WriteBinary(msg)
		})

		ws.m.HandleSentMessageBinary(func(s *Session, msg []byte) {
			assert.Equal(t, TestMsg, msg)
			close(done)
		})
	}, func(conn *websocket.Conn) {
		conn.WriteMessage(websocket.BinaryMessage, TestMsg)
	})
}