利用 Go 语言编写一个简单的 WebSocket 推送服务

2020-01-28 13:04:04刘景俊

首先将传入的 http.Request 转换为 websocket.Conn,再将其分装为我们自定义的一个 wserver.Conn(封装,或者说是组合,是 Go 语言的典型用法。记住,Go 语言没有继承,只有组合)。然后设置了 Conn 的 AfterReadFunc 和 BeforeCloseFunc 方法,接着启动了 conn.Listen()。AfterReadFunc 意思是当 Conn 读取到数据后,尝试验证并根据 token 计算 userID,然乎 bind 注册绑定。BeforeCloseFunc 则为 Conn 关闭前进行解绑操作。

pushHandler

pushHandler 则容易理解。它解析请求然后推送数据:


// Authorize if needed. Then decode the request and push message to each
// realted websocket connection.
func (s *pushHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 if r.Method != http.MethodPost {
  w.WriteHeader(http.StatusMethodNotAllowed)
  return
 }
 // authorize
 if s.authFunc != nil {
  if ok := s.authFunc(r); !ok {
   w.WriteHeader(http.StatusUnauthorized)
   return
  }
 }
 // read request
 var pm PushMessage
 decoder := json.NewDecoder(r.Body)
 if err := decoder.Decode(&pm); err != nil {
  w.WriteHeader(http.StatusBadRequest)
  w.Write([]byte(ErrRequestIllegal.Error()))
  return
 }
 // validate the data
 if pm.UserID == "" || pm.Event == "" || pm.Message == "" {
  w.WriteHeader(http.StatusBadRequest)
  w.Write([]byte(ErrRequestIllegal.Error()))
  return
 }
 cnt, err := s.push(pm.UserID, pm.Event, pm.Message)
 if err != nil {
  w.WriteHeader(http.StatusInternalServerError)
  w.Write([]byte(err.Error()))
  return
 }
 result := strings.NewReader(fmt.Sprintf("message sent to %d clients", cnt))
 io.Copy(w, result)
}

Conn

Conn (此处指 wserver.Conn) 为 websocket.Conn 的包装。


// Conn wraps websocket.Conn with Conn. It defines to listen and read
// data from Conn.
type Conn struct {
 Conn *websocket.Conn
 AfterReadFunc func(messageType int, r io.Reader)
 BeforeCloseFunc func()

 once sync.Once
 id  string
 stopCh chan struct{}
}

最主要的方法为 Listen():


// Listen listens for receive data from websocket connection. It blocks
// until websocket connection is closed.
func (c *Conn) Listen() {
 c.Conn.SetCloseHandler(func(code int, text string) error {
  if c.BeforeCloseFunc != nil {
   c.BeforeCloseFunc()
  }
  if err := c.Close(); err != nil {
   log.Println(err)
  }
  message := websocket.FormatCloseMessage(code, "")
  c.Conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
  return nil
 })
 // Keeps reading from Conn util get error.
ReadLoop:
 for {
  select {
  case <-c.stopCh:
   break ReadLoop
  default:
   messageType, r, err := c.Conn.NextReader()
   if err != nil {
    // TODO: handle read error maybe
    break ReadLoop
   }
   if c.AfterReadFunc != nil {
    c.AfterReadFunc(messageType, r)
   }
  }
 }
}

主要设置了当 websocket 连接关闭时的处理和不停地读取数据。

文中很难全面地描述整个代码的运作流程,像具体阅读代码,请前往 github.com/alfred-zhong/wserver 获取。