389 lines
8.1 KiB
Go
389 lines
8.1 KiB
Go
package melody
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"service/api/base"
|
|
"service/bizcommon/util"
|
|
"service/library/logger"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// pingPeriod is reported to clients, converted to whole seconds, as the
// PingInterval field of the init response built by Init. The server's own ping
// ticker in writePump is driven by Config.PingPeriod, not by this constant.
const pingPeriod = time.Minute
|
|
|
|
// Session wraps one websocket connection together with the per-client state
// this package keeps for it.
type Session struct {
	// Sid is the session id. It is empty until Init assigns one via genSid;
	// while it is empty, readPump only accepts init messages.
	Sid string
	// Mid is an int64 client identifier mixed into the session id by genSid
	// (presumably a member/user id — TODO confirm).
	Mid int64
	// Did is a string identifier mixed into the session id by genSid
	// (presumably a device id — TODO confirm).
	Did string
	// DevType is an int32 device type; its values are not interpreted in this file.
	DevType int32
	// Request is assumed to be the HTTP request that was upgraded to this
	// websocket. It is not set in this file — TODO confirm.
	Request *http.Request
	// Keys holds arbitrary per-session values. Access it through
	// Set/Get/MustGet/UnSet, which guard it with rwmutex.
	Keys map[string]interface{}

	// conn is the underlying websocket connection.
	conn *websocket.Conn
	// output queues outgoing messages. writeMessage sends to it without
	// blocking, and writePump drains it.
	output chan *CMsg
	// outputDone is closed by close() to make writePump stop.
	outputDone chan struct{}
	// melody is the owning Melody instance (config and event handlers).
	melody *Melody
	// open becomes false once close() has run. It is guarded by rwmutex.
	open bool
	// rwmutex guards open and Keys.
	rwmutex *sync.RWMutex
}
|
|
|
|
func genSessionId(r *base.BaseRequest) string {
|
|
s := fmt.Sprintf("%s_%d_%s", r.Did, r.Mid, time.Now().Format(time.RFC3339))
|
|
return fmt.Sprintf("%x", md5.Sum([]byte(s)))
|
|
}
|
|
|
|
func (s *Session) writeMessage(message *CMsg) {
|
|
if s.closed() {
|
|
s.melody.errorHandler(s, ErrWriteClosed)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case s.output <- message:
|
|
default:
|
|
s.melody.errorHandler(s, ErrMessageBufferFull)
|
|
}
|
|
}
|
|
|
|
func (s *Session) writeRaw(message *CMsg) error {
|
|
if s.closed() {
|
|
return ErrWriteClosed
|
|
}
|
|
logger.Info("_writeRaw, m: %v", util.ToJson(message))
|
|
|
|
//err := s.conn.SetWriteDeadline(time.Now().Add(s.melody.Config.WriteWait))
|
|
//if err != nil {
|
|
// return err
|
|
//}
|
|
//
|
|
//w, err := s.conn.NextWriter(websocket.BinaryMessage)
|
|
//if err != nil {
|
|
// return err
|
|
//}
|
|
//
|
|
//_, err = w.Write(util.MustMarshal(message))
|
|
//if err != nil {
|
|
// return err
|
|
//}
|
|
|
|
s.conn.SetWriteDeadline(time.Now().Add(s.melody.Config.WriteWait))
|
|
err := s.conn.WriteMessage(message.T, message.Msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) closed() bool {
|
|
s.rwmutex.RLock()
|
|
defer s.rwmutex.RUnlock()
|
|
|
|
return !s.open
|
|
}
|
|
|
|
func (s *Session) close() {
|
|
s.rwmutex.Lock()
|
|
open := s.open
|
|
s.open = false
|
|
s.rwmutex.Unlock()
|
|
if open {
|
|
s.conn.Close()
|
|
close(s.outputDone)
|
|
}
|
|
}
|
|
|
|
func (s *Session) ping() {
|
|
s.writeRaw(&CMsg{T: websocket.PingMessage, Msg: []byte{}})
|
|
}
|
|
|
|
func (s *Session) writePump() {
|
|
ticker := time.NewTicker(s.melody.Config.PingPeriod)
|
|
defer ticker.Stop()
|
|
|
|
loop:
|
|
for {
|
|
select {
|
|
case msg := <-s.output:
|
|
err := s.writeRaw(msg)
|
|
|
|
if err != nil {
|
|
logger.Error("writeRaw fail, err: %v", err)
|
|
break loop
|
|
}
|
|
|
|
if msg.T == websocket.CloseMessage {
|
|
break loop
|
|
}
|
|
|
|
//if msg.T == websocket.TextMessage {
|
|
// s.melody.messageSentHandler(s, msg.Msg)
|
|
//}
|
|
//
|
|
//if msg.T == websocket.BinaryMessage {
|
|
// s.melody.messageSentHandlerBinary(s, msg.Msg)
|
|
//}
|
|
case <-ticker.C:
|
|
s.ping()
|
|
case _, ok := <-s.outputDone:
|
|
if !ok {
|
|
break loop
|
|
}
|
|
}
|
|
}
|
|
|
|
s.close()
|
|
}
|
|
|
|
func (s *Session) readPump() {
|
|
s.conn.SetReadLimit(s.melody.Config.MaxMessageSize)
|
|
s.conn.SetReadDeadline(time.Now().Add(s.melody.Config.PongWait))
|
|
|
|
s.conn.SetPongHandler(func(string) error {
|
|
s.conn.SetReadDeadline(time.Now().Add(s.melody.Config.PongWait))
|
|
s.melody.pongHandler(s)
|
|
return nil
|
|
})
|
|
|
|
if s.melody.closeHandler != nil {
|
|
s.conn.SetCloseHandler(func(code int, text string) error {
|
|
return s.melody.closeHandler(s, code, text)
|
|
})
|
|
}
|
|
|
|
for {
|
|
_, messageData, err := s.conn.ReadMessage()
|
|
if err != nil {
|
|
logger.Error("ReadMessage fail, s: %v, err: %v", s.String(), err)
|
|
break
|
|
}
|
|
logger.Info("ReadMsg: %v", string(messageData))
|
|
|
|
msg := CMsg{}
|
|
err = json.Unmarshal(messageData, &msg)
|
|
|
|
if s.Sid == "" && msg.T != CMsgTypeInit && msg.T != CMsgTypeInitResp {
|
|
logger.Warn("recv msg from uninited conn: %s, msg: %s", s, msg.String())
|
|
continue
|
|
}
|
|
|
|
switch msg.T {
|
|
case CMsgTypeInit, CMsgTypeInitResp:
|
|
err := s.Init()
|
|
if err != nil {
|
|
_ = s.Close()
|
|
logger.Error("failed to init conn: %s - %s", s.String(), err.Error())
|
|
break
|
|
}
|
|
|
|
//case CMsgTypeBiz:
|
|
// if c.BizMsgHandler != nil {
|
|
// err = c.BizMsgHandler(c, &msg)
|
|
// if err != nil {
|
|
// logger.Error("failed to proc msg, conn: %s, msg: %s, err: %s", c.Str(), msg.String(), err.Error())
|
|
// }
|
|
// continue
|
|
// }
|
|
//
|
|
// poster := c.Hub.GetPoster(msg.ChannelId)
|
|
// if poster == nil {
|
|
// logger.Error("channel poster not found: %s", msg.ChannelId)
|
|
// continue
|
|
// }
|
|
// chanmsg = lib.ChanMsg{
|
|
// Id: msg.Id,
|
|
// SessionId: c.Session,
|
|
// Mid: c.Mid,
|
|
// Data: msg.Data,
|
|
// IP: c.IP,
|
|
// }
|
|
// err = poster.SendMsg(chanmsg)
|
|
// if err != nil {
|
|
// logger.Error("failed to Send msg to channel: msg: %s, %s", msg.String(), err.Error())
|
|
// }
|
|
|
|
default:
|
|
// TODO: 处理其他类型消息
|
|
}
|
|
}
|
|
}
|
|
|
|
// Write writes message to session.
|
|
func (s *Session) Write(msg []byte) error {
|
|
if s.closed() {
|
|
return ErrSessionClosed
|
|
}
|
|
|
|
s.writeMessage(&CMsg{T: websocket.TextMessage, Msg: msg})
|
|
|
|
return nil
|
|
}
|
|
|
|
// WriteBinary writes a binary message to session.
|
|
func (s *Session) WriteBinary(msg []byte) error {
|
|
if s.closed() {
|
|
return ErrSessionClosed
|
|
}
|
|
|
|
s.writeMessage(&CMsg{T: websocket.BinaryMessage, Msg: msg})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) WriteBiz(msg []byte) error {
|
|
if s.closed() {
|
|
return ErrSessionClosed
|
|
}
|
|
|
|
s.writeMessage(&CMsg{T: websocket.BinaryMessage, Msg: msg})
|
|
|
|
return nil
|
|
}
|
|
|
|
// WriteRaw writes a raw message to session.
|
|
func (s *Session) WriteRaw(msg *CMsg) error {
|
|
if s.closed() {
|
|
return ErrSessionClosed
|
|
}
|
|
|
|
s.writeMessage(msg)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close closes session.
|
|
func (s *Session) Close() error {
|
|
if s.closed() {
|
|
return ErrSessionClosed
|
|
}
|
|
|
|
s.writeMessage(&CMsg{T: websocket.CloseMessage, Msg: []byte{}})
|
|
|
|
return nil
|
|
}
|
|
|
|
// CloseWithMsg closes the session with the provided payload.
|
|
// Use the FormatCloseMessage function to format a proper close message payload.
|
|
func (s *Session) CloseWithMsg(msg []byte) error {
|
|
if s.closed() {
|
|
return ErrSessionClosed
|
|
}
|
|
|
|
s.writeMessage(&CMsg{T: websocket.CloseMessage, Msg: msg})
|
|
|
|
return nil
|
|
}
|
|
|
|
// Set is used to store a new key/value pair exclusively for this session.
|
|
// It also lazy initializes s.Keys if it was not used previously.
|
|
func (s *Session) Set(key string, value interface{}) {
|
|
s.rwmutex.Lock()
|
|
defer s.rwmutex.Unlock()
|
|
|
|
if s.Keys == nil {
|
|
s.Keys = make(map[string]interface{})
|
|
}
|
|
|
|
s.Keys[key] = value
|
|
}
|
|
|
|
// Get returns the value for the given key, ie: (value, true).
|
|
// If the value does not exists it returns (nil, false)
|
|
func (s *Session) Get(key string) (value interface{}, exists bool) {
|
|
s.rwmutex.RLock()
|
|
defer s.rwmutex.RUnlock()
|
|
|
|
if s.Keys != nil {
|
|
value, exists = s.Keys[key]
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// MustGet returns the value for the given key if it exists, otherwise it panics.
|
|
func (s *Session) MustGet(key string) interface{} {
|
|
if value, exists := s.Get(key); exists {
|
|
return value
|
|
}
|
|
|
|
panic("Key \"" + key + "\" does not exist")
|
|
}
|
|
|
|
// UnSet will delete the key and has no return value
|
|
func (s *Session) UnSet(key string) {
|
|
s.rwmutex.Lock()
|
|
defer s.rwmutex.Unlock()
|
|
if s.Keys != nil {
|
|
delete(s.Keys, key)
|
|
}
|
|
}
|
|
|
|
// IsClosed returns the status of the connection.
|
|
func (s *Session) IsClosed() bool {
|
|
return s.closed()
|
|
}
|
|
|
|
// LocalAddr returns the local addr of the connection.
|
|
func (s *Session) LocalAddr() net.Addr {
|
|
return s.conn.LocalAddr()
|
|
}
|
|
|
|
// RemoteAddr returns the remote addr of the connection.
|
|
func (s *Session) RemoteAddr() net.Addr {
|
|
return s.conn.RemoteAddr()
|
|
}
|
|
|
|
func (s *Session) Init() error {
|
|
if s.Sid != "" {
|
|
logger.Warn("recv init msg from inited conn: %s", s.String())
|
|
return nil
|
|
}
|
|
|
|
initData := InitMsgData{}
|
|
initData.Sid = genSid(s)
|
|
s.Sid = initData.Sid
|
|
initData.PingInterval = int(pingPeriod / time.Second)
|
|
|
|
msgBs, _ := json.Marshal(initData)
|
|
iMsg := &CMsg{
|
|
Id: GenMsgId(),
|
|
T: CMsgTypeInitResp,
|
|
Msg: json.RawMessage(msgBs),
|
|
}
|
|
// init阶段发送json非压缩数据
|
|
//err := s.conn.WriteMessage(websocket.TextMessage, []byte("abcd"))
|
|
//if err != nil {
|
|
// logger.Error("WriteMessage fail, err: %v", err)
|
|
// return err
|
|
//}
|
|
//err = s.conn.WriteJSON(iMsg)
|
|
//if err != nil {
|
|
// logger.Error("WriteJSON fail, err: %v", err)
|
|
// return err
|
|
//}
|
|
bs, _ := json.Marshal(iMsg)
|
|
err := s.WriteBinary(bs)
|
|
if err != nil {
|
|
logger.Error("WriteBinary fail, err: %v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) String() string {
|
|
return fmt.Sprintf("Mid: %d, sid: %s", s.Mid, s.Sid)
|
|
}
|
|
|
|
func genSid(s *Session) string {
|
|
str := fmt.Sprintf("%s_%d_%s", s.Did, s.Mid, time.Now().Format(time.RFC3339))
|
|
return fmt.Sprintf("%x", md5.Sum([]byte(str)))[:16]
|
|
}
|