a4a1b63a创建于 2024年10月14日历史提交
// Copyright (C) 2014-2018 Goodrain Co., Ltd.
// RAINBOND, Application Management Platform

// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version. For any non-GPL usage of Rainbond,
// one or multiple Commercial Licenses authorized by Goodrain Co., Ltd.
// must be obtained first.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package util

import (
	"errors"
	"io"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/net/context"

	"github.com/sirupsen/logrus"
)

// Error type
var (
	ErrConnClosing   = errors.New("use of closed network connection")
	ErrWriteBlocking = errors.New("write packet was blocking")
	ErrReadBlocking  = errors.New("read packet was blocking")
)

// Conn exposes a set of callbacks for the various events that occur on a connection
type Conn struct {
	srv               *Server
	conn              *net.TCPConn  // the raw connection
	extraData         interface{}   // to save extra data
	closeOnce         sync.Once     // close the conn, once, per instance
	closeFlag         int32         // close flag
	closeChan         chan struct{} // close chanel
	packetSendChan    chan Packet   // packet send chanel
	packetReceiveChan chan Packet   // packeet receive chanel
	buffer            *Buffer
	ctx               context.Context
	pro               Protocol
	timer             *time.Timer
}

// ConnCallback is an interface of methods that are used as callbacks on a connection
type ConnCallback interface {
	// OnConnect is called when the connection was accepted,
	// If the return value of false is closed
	OnConnect(*Conn) bool

	// OnMessage is called when the connection receives a packet,
	// If the return value of false is closed
	OnMessage(Packet) bool

	// OnClose is called when the connection closed
	OnClose(*Conn)
}

// newConn returns a wrapper of raw conn
func newConn(conn *net.TCPConn, srv *Server, ctx context.Context) *Conn {
	p := &MessageProtocol{}
	p.SetConn(conn)
	conn.SetLinger(3)
	conn.SetReadBuffer(1024 * 1024 * 24)
	return &Conn{
		ctx:               ctx,
		srv:               srv,
		conn:              conn,
		closeChan:         make(chan struct{}),
		packetSendChan:    make(chan Packet, srv.config.PacketSendChanLimit),
		packetReceiveChan: make(chan Packet, srv.config.PacketReceiveChanLimit),
		pro:               p,
	}
}

// GetExtraData gets the extra data from the Conn
func (c *Conn) GetExtraData() interface{} {
	return c.extraData
}

// PutExtraData puts the extra data with the Conn
func (c *Conn) PutExtraData(data interface{}) {
	c.extraData = data
}

// GetRawConn returns the raw net.TCPConn from the Conn
func (c *Conn) GetRawConn() *net.TCPConn {
	return c.conn
}

// Close closes the connection
func (c *Conn) Close() {
	c.closeOnce.Do(func() {
		atomic.StoreInt32(&c.closeFlag, 1)
		close(c.closeChan)
		close(c.packetSendChan)
		close(c.packetReceiveChan)
		c.conn.Close()
		c.srv.callback.OnClose(c)
	})
}

// IsClosed indicates whether or not the connection is closed
func (c *Conn) IsClosed() bool {
	return atomic.LoadInt32(&c.closeFlag) == 1
}

// AsyncWritePacket async writes a packet, this method will never block
func (c *Conn) AsyncWritePacket(p Packet, timeout time.Duration) (err error) {
	if c.IsClosed() {
		return ErrConnClosing
	}

	defer func() {
		if e := recover(); e != nil {
			err = ErrConnClosing
		}
	}()

	if timeout == 0 {
		select {
		case c.packetSendChan <- p:
			return nil

		default:
			return ErrWriteBlocking
		}

	} else {
		select {
		case c.packetSendChan <- p:
			return nil

		case <-c.closeChan:
			return ErrConnClosing

		case <-time.After(timeout):
			return ErrWriteBlocking
		}
	}
}

// Do it
func (c *Conn) Do() {
	if !c.srv.callback.OnConnect(c) {
		return
	}
	asyncDo(c.readLoop, c.srv.waitGroup)
}

var timeOut = time.Second * 15

func (c *Conn) readLoop() {
	defer func() {
		if err := recover(); err != nil {
			logrus.Error(err)
		}
		c.Close()
	}()
	//15秒未接受到消息或ping,则关闭连接
	c.timer = time.NewTimer(timeOut)
	defer c.timer.Stop()
	asyncDo(c.readPing, c.srv.waitGroup)
	for {
		select {
		case <-c.srv.exitChan:
			return
		case <-c.ctx.Done():
			return
		case <-c.closeChan:
			return
		default:
		}
		p, err := c.pro.ReadPacket()
		if err == io.EOF {
			return
		}
		if err == io.ErrUnexpectedEOF {
			return
		}
		if err == errClosed {
			return
		}
		if err == io.ErrNoProgress {
			return
		}
		if err != nil {
			if strings.HasSuffix(err.Error(), "use of closed network connection") {
				logrus.Error("use of closed network connection")
				return
			}
			logrus.Error("read package error:", err.Error())
			return
		}
		if p.IsNull() {
			return
		}
		if p.IsPing() {
			if ok := c.timer.Reset(timeOut); !ok {
				c.timer = time.NewTimer(timeOut)
			}
			continue
		}
		if ok := c.srv.callback.OnMessage(p); !ok {
			continue
		}
		if ok := c.timer.Reset(timeOut); !ok {
			c.timer = time.NewTimer(timeOut)
		}
	}
}

func (c *Conn) readPing() {
	for {
		select {
		case <-c.srv.exitChan:
			return
		case <-c.ctx.Done():
			return
		case <-c.closeChan:
			return
		case <-c.timer.C:
			logrus.Debug("can not receive message more than 15s.close the con")
			c.conn.Close()
			return

		}
	}
}

func asyncDo(fn func(), wg *sync.WaitGroup) {
	//wg.Add(1)
	go func() {
		fn()
		//wg.Done()
	}()
}