package util
import (
"errors"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/context"
"github.com/sirupsen/logrus"
)
// Sentinel errors returned by Conn operations.
var (
	// ErrConnClosing is returned when an operation is attempted on a Conn
	// that has been, or is being, closed.
	ErrConnClosing = errors.New("use of closed network connection")

	// ErrWriteBlocking is returned when a packet cannot be queued for
	// sending, either immediately or within the caller's timeout.
	ErrWriteBlocking = errors.New("write packet was blocking")

	// ErrReadBlocking reports that reading a packet would block.
	// NOTE(review): not referenced in this file — confirm its users elsewhere.
	ErrReadBlocking = errors.New("read packet was blocking")
)
// Conn wraps a single accepted TCP connection together with the state used
// to read packets from it, queue packets for it, and shut it down once.
type Conn struct {
	// srv is the owning server; Conn reads its config, callback, exitChan
	// and waitGroup.
	srv *Server
	// conn is the underlying TCP socket.
	conn *net.TCPConn
	// extraData holds an arbitrary caller value (see GetExtraData /
	// PutExtraData). Access is not synchronized.
	extraData interface{}
	// closeOnce guarantees the shutdown sequence in Close runs only once.
	closeOnce sync.Once
	// closeFlag is set to 1 (atomically) when Close begins.
	closeFlag int32
	// closeChan is closed by Close to signal shutdown to waiting goroutines.
	closeChan chan struct{}
	// packetSendChan is the outbound queue filled by AsyncWritePacket and
	// closed by Close. NOTE(review): its consumer is not in this file.
	packetSendChan chan Packet
	// packetReceiveChan is created by newConn and closed by Close.
	// NOTE(review): no reader or writer is visible in this file.
	packetReceiveChan chan Packet
	// buffer is not referenced in this file.
	buffer *Buffer
	// ctx is the per-connection context; its cancellation ends readLoop.
	ctx context.Context
	// pro decodes packets from the socket (ReadPacket).
	pro Protocol
	// timer is the idle timer: readLoop resets it on activity and readPing
	// closes the socket when it fires.
	timer *time.Timer
}
// ConnCallback receives connection lifecycle events from the server.
type ConnCallback interface {
	// OnConnect is called by Conn.Do before any reading starts; returning
	// false prevents the read loop from being started.
	OnConnect(*Conn) bool
	// OnMessage is called for every received packet that is neither null
	// nor a ping. Returning false means the packet does not reset the idle
	// timer.
	OnMessage(Packet) bool
	// OnClose is called once, from Conn.Close, after the socket is closed.
	OnClose(*Conn)
}
// newConn wraps an accepted TCP connection in a Conn owned by srv.
//
// It binds a fresh MessageProtocol to the socket, applies best-effort socket
// tuning (3-second linger, 24 MiB read buffer) and allocates the send and
// receive channels with the capacities configured on srv. The idle timer is
// not created here; readLoop arms it.
func newConn(conn *net.TCPConn, srv *Server, ctx context.Context) *Conn {
	p := &MessageProtocol{}
	p.SetConn(conn)

	// Socket tuning is best-effort: a failure must not reject the connection,
	// but it should not disappear silently either.
	if err := conn.SetLinger(3); err != nil {
		logrus.Warn("set linger failed: ", err)
	}
	if err := conn.SetReadBuffer(1024 * 1024 * 24); err != nil {
		logrus.Warn("set read buffer failed: ", err)
	}

	return &Conn{
		ctx:               ctx,
		srv:               srv,
		conn:              conn,
		closeChan:         make(chan struct{}),
		packetSendChan:    make(chan Packet, srv.config.PacketSendChanLimit),
		packetReceiveChan: make(chan Packet, srv.config.PacketReceiveChanLimit),
		pro:               p,
	}
}
// GetExtraData returns the value most recently stored with PutExtraData, or
// nil if nothing has been stored. The read is not synchronized.
func (c *Conn) GetExtraData() interface{} {
	data := c.extraData
	return data
}
// PutExtraData attaches an arbitrary caller-owned value to the connection,
// replacing any previous one. The write is not synchronized; callers that
// share a Conn across goroutines must coordinate access themselves.
func (c *Conn) PutExtraData(data interface{}) {
	c.extraData = data
}
// GetRawConn exposes the underlying TCP socket. Reading from or closing it
// directly bypasses the Conn's protocol and shutdown bookkeeping.
func (c *Conn) GetRawConn() *net.TCPConn {
	raw := c.conn
	return raw
}
// Close shuts the connection down. Only the first call has any effect;
// later calls return immediately.
//
// The shutdown sequence marks the Conn closed, closes closeChan and both
// packet channels, closes the socket, and finally notifies the server's
// OnClose callback.
func (c *Conn) Close() {
	shutdown := func() {
		// Set the flag first so IsClosed reports true no later than the
		// channels below are closed.
		atomic.StoreInt32(&c.closeFlag, 1)
		close(c.closeChan)
		close(c.packetSendChan)
		close(c.packetReceiveChan)
		c.conn.Close()
		c.srv.callback.OnClose(c)
	}
	c.closeOnce.Do(shutdown)
}
// IsClosed reports whether Close has started shutting the connection down.
// It is safe to call from any goroutine.
func (c *Conn) IsClosed() bool {
	flag := atomic.LoadInt32(&c.closeFlag)
	return flag == 1
}
// AsyncWritePacket queues p on the connection's send channel.
//
// With timeout == 0 the send is non-blocking and ErrWriteBlocking is returned
// immediately if the channel is full. Otherwise it waits up to timeout for
// room, returning ErrConnClosing if the connection is closed meanwhile and
// ErrWriteBlocking if the timeout elapses first.
//
// ErrConnClosing is also returned when the connection is already closed, or
// when Close closes the send channel concurrently with this send; the
// resulting send-on-closed-channel panic is recovered and converted.
func (c *Conn) AsyncWritePacket(p Packet, timeout time.Duration) (err error) {
	if c.IsClosed() {
		return ErrConnClosing
	}
	defer func() {
		if e := recover(); e != nil {
			err = ErrConnClosing
		}
	}()

	if timeout == 0 {
		select {
		case c.packetSendChan <- p:
			return nil
		default:
			return ErrWriteBlocking
		}
	}

	// An explicit timer is stopped as soon as the select resolves. time.After
	// would keep its timer alive until it fired (before Go 1.23), which adds
	// up when sends usually succeed long before a large timeout.
	t := time.NewTimer(timeout)
	defer t.Stop()
	select {
	case c.packetSendChan <- p:
		return nil
	case <-c.closeChan:
		return ErrConnClosing
	case <-t.C:
		return ErrWriteBlocking
	}
}
// Do offers the connection to the server's OnConnect callback and, if it is
// accepted, starts readLoop in a new goroutine tracked by the server's
// wait group.
//
// NOTE(review): when OnConnect rejects the connection it is not closed
// here — confirm the callback or the caller closes it.
func (c *Conn) Do() {
	accepted := c.srv.callback.OnConnect(c)
	if accepted {
		asyncDo(c.readLoop, c.srv.waitGroup)
	}
}
// timeOut is how long a connection may stay idle (no ping and no accepted
// message) before readPing closes its socket.
var timeOut = 15 * time.Second
// readLoop is the per-connection receive loop, started in its own goroutine
// by Do. It arms the idle timer and starts the readPing watchdog. It then
// reads packets until one of the following ends the connection:
//   - the server's exitChan, the connection context, or closeChan has fired
//     (checked before each read);
//   - ReadPacket returns any error;
//   - a null packet is received.
//
// Ping packets, and packets that OnMessage accepts (returns true for), reset
// the idle timer. Packets rejected by OnMessage leave the timer running.
// When the loop ends, any panic is logged and the Conn is closed.
func (c *Conn) readLoop() {
	defer func() {
		if err := recover(); err != nil {
			logrus.Error(err)
		}
		c.Close()
	}()

	// c.timer is assigned once, before readPing starts, and never replaced:
	// readPing waits on this instance's channel and the deferred Stop
	// targets this instance.
	c.timer = time.NewTimer(timeOut)
	defer c.timer.Stop()

	asyncDo(c.readPing, c.srv.waitGroup)

	for {
		select {
		case <-c.srv.exitChan:
			return
		case <-c.ctx.Done():
			return
		case <-c.closeChan:
			return
		default:
		}

		p, err := c.pro.ReadPacket()
		switch {
		case err == io.EOF, err == io.ErrUnexpectedEOF, err == errClosed, err == io.ErrNoProgress:
			// Orderly or expected termination: stop quietly.
			return
		case err != nil:
			if strings.HasSuffix(err.Error(), "use of closed network connection") {
				logrus.Error("use of closed network connection")
				return
			}
			logrus.Error("read package error:", err.Error())
			return
		}

		if p.IsNull() {
			return
		}
		if p.IsPing() {
			c.resetIdle()
			continue
		}
		if !c.srv.callback.OnMessage(p) {
			continue
		}
		c.resetIdle()
	}
}

// resetIdle re-arms the idle timer for another timeOut period.
//
// Reset re-arms the timer whether or not it was still active, so its
// boolean result carries no action here. NOTE(review): before Go 1.23, a
// Reset racing with an expiry can still deliver the stale tick to readPing.
func (c *Conn) resetIdle() {
	c.timer.Reset(timeOut)
}
// readPing is the watchdog goroutine started by readLoop. It waits for the
// first of: server exit, cancellation of the connection context, Close, or
// expiry of the idle timer.
//
// In every case except Close (which has already closed the socket), it
// closes the underlying TCP socket. readLoop only checks exitChan and the
// context between reads, so without this a readLoop blocked in ReadPacket
// would keep running after this watchdog returns. Closing the socket
// presumably makes the pending ReadPacket fail, and readLoop then tears the
// Conn down.
func (c *Conn) readPing() {
	select {
	case <-c.srv.exitChan:
		c.conn.Close()
	case <-c.ctx.Done():
		c.conn.Close()
	case <-c.closeChan:
		// Close has already closed the socket; nothing left to do.
	case <-c.timer.C:
		logrus.Debug("can not receive message more than 15s.close the con")
		c.conn.Close()
	}
}
// asyncDo runs fn in a new goroutine.
//
// When wg is non-nil the goroutine is registered with it: Add is called
// before the goroutine starts (so a concurrent Wait cannot miss it) and Done
// when fn returns. A caller blocked in wg.Wait therefore waits for fn to
// finish. A nil wg starts an untracked goroutine.
func asyncDo(fn func(), wg *sync.WaitGroup) {
	if wg == nil {
		go fn()
		return
	}
	wg.Add(1)
	go func() {
		defer wg.Done()
		fn()
	}()
}