package proxy
import (
"context"
"fmt"
"github.com/goodrain/rainbond/pkg/gogo"
"log"
"net"
"net/http"
"net/url"
"strings"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
)
// WebSocketProxy forwards WebSocket connections to one of a set of backend
// endpoints, chosen per request by a load balancer.
type WebSocketProxy struct {
	// name identifies this proxy; createWebSocketProxy chooses the
	// load-balancing strategy based on it.
	name string
	// endpoints is the current backend set; UpdateEndpoints replaces it.
	endpoints EndpointList
	// lb picks the backend endpoint for each incoming request.
	lb LoadBalance
	// upgrader upgrades client connections to WebSocket; when nil, Proxy
	// uses a built-in default.
	upgrader *websocket.Upgrader
}
// Proxy relays a WebSocket session between the client request and a backend
// endpoint picked by h.lb.
//
// The backend is dialed first. The dial forwards the client's Origin,
// Sec-WebSocket-Protocol and Cookie headers, plus an X-Forwarded-For header
// extended with the client address. Only then is the client connection
// upgraded, echoing the backend's negotiated Sec-Websocket-Protocol and its
// Set-Cookie headers.
//
// Messages are then copied in both directions until either side fails or
// closes. The first error ends the session, and both connections are closed.
// The request's query string is not forwarded to the backend.
//
// If the backend cannot be dialed, the client receives an HTTP error instead
// of an upgrade. It is the backend's own 4xx/5xx status when the rejected
// handshake carried one, otherwise 502 Bad Gateway.
func (h *WebSocketProxy) Proxy(w http.ResponseWriter, req *http.Request) {
	endpoint := h.lb.Select(req, h.endpoints)

	// Forward the path as the client sent it, minus the query string.
	path := req.RequestURI
	if i := strings.Index(path, "?"); i >= 0 {
		path = path[:i]
	}
	u := url.URL{Scheme: "ws", Host: endpoint.GetAddr(), Path: path}
	// RequestURI is already percent-encoded. Storing it only in Path would make
	// u.String() encode it a second time ("%20" -> "%2520"). Instead, keep the
	// decoded form in Path and the original encoding in RawPath. If the escapes
	// are malformed, fall back to the plain path set above.
	if decoded, err := url.PathUnescape(path); err == nil {
		u.Path, u.RawPath = decoded, path
	}

	requestHeader := http.Header{}
	if origin := req.Header.Get("Origin"); origin != "" {
		requestHeader.Add("Origin", origin)
	}
	for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
		requestHeader.Add("Sec-WebSocket-Protocol", prot)
	}
	for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] {
		requestHeader.Add("Cookie", cookie)
	}
	// Append this hop's client address to any X-Forwarded-For chain already
	// present on the incoming request.
	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		if prior, ok := req.Header["X-Forwarded-For"]; ok {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		requestHeader.Set("X-Forwarded-For", clientIP)
	}

	connBackend, resp, err := websocket.DefaultDialer.Dial(u.String(), requestHeader)
	if err != nil {
		log.Printf("websocketproxy: couldn't dial to remote backend url %s: %v\n", u.String(), err)
		// Answer the client explicitly. Returning without writing anything would
		// send an empty 200 that is neither an upgrade nor an error.
		status := http.StatusBadGateway
		if resp != nil && resp.StatusCode >= http.StatusBadRequest {
			status = resp.StatusCode
		}
		http.Error(w, http.StatusText(status), status)
		return
	}
	defer connBackend.Close()

	upgradeHeader := http.Header{}
	if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" {
		upgradeHeader.Set("Sec-Websocket-Protocol", hdr)
	}
	// The backend may set several cookies; forward every Set-Cookie line.
	for _, cookie := range resp.Header[http.CanonicalHeaderKey("Set-Cookie")] {
		upgradeHeader.Add("Set-Cookie", cookie)
	}

	// Use a local variable instead of lazily assigning h.upgrader. Proxy runs
	// concurrently for every connection, and writing the shared field from
	// here would be a data race.
	upgrader := h.upgrader
	if upgrader == nil {
		upgrader = &websocket.Upgrader{
			ReadBufferSize:    1024,
			WriteBufferSize:   1024,
			EnableCompression: true,
			// Report the status gorilla determined (e.g. 400 for a malformed
			// handshake) rather than a blanket 500.
			Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
				http.Error(w, http.StatusText(status), status)
			},
			// Every origin is accepted here. The client's Origin header is
			// forwarded to the backend above, so any origin policy has to be
			// enforced there.
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		}
	}
	connPub, err := upgrader.Upgrade(w, req, upgradeHeader)
	if err != nil {
		log.Printf("websocketproxy: couldn't upgrade %s\n", err)
		return
	}
	defer connPub.Close()

	// closeMessage builds a close-frame payload that is legal to send, based on
	// the read error that ended one direction of the relay.
	closeMessage := func(err error) []byte {
		code, text := websocket.CloseNormalClosure, fmt.Sprintf("%v", err)
		if e, ok := err.(*websocket.CloseError); ok {
			switch e.Code {
			case websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure, websocket.CloseTLSHandshake:
				// These codes are reserved and must never appear in a close frame
				// (RFC 6455 section 7.4.1), so keep the normal-closure default.
			default:
				code, text = e.Code, e.Text
			}
		}
		// A control frame carries at most 125 bytes: a 2-byte code plus the
		// reason. Cut the reason on a UTF-8 rune boundary (skip back over
		// continuation bytes 10xxxxxx) so it stays valid text.
		const maxCloseReason = 123
		if len(text) > maxCloseReason {
			cut := maxCloseReason
			for cut > 0 && text[cut]&0xC0 == 0x80 {
				cut--
			}
			text = text[:cut]
		}
		return websocket.FormatCloseMessage(code, text)
	}

	// replicate copies messages from src to dst until a read or write fails,
	// then reports the error on errc (buffered, so the send never blocks).
	replicate := func(dst, src *websocket.Conn, errc chan<- error) {
		for {
			msgType, msg, err := src.ReadMessage()
			if err != nil {
				// Send the close frame before signalling. As soon as errc is read,
				// Proxy returns and its deferred Close calls would drop the frame.
				// This is best effort: dst may already be gone.
				_ = dst.WriteMessage(websocket.CloseMessage, closeMessage(err))
				errc <- err
				return
			}
			if err = dst.WriteMessage(msgType, msg); err != nil {
				errc <- err
				return
			}
		}
	}

	errClient := make(chan error, 1)
	errBackend := make(chan error, 1)
	_ = gogo.Go(func(ctx context.Context) error {
		replicate(connPub, connBackend, errClient)
		return nil
	})
	_ = gogo.Go(func(ctx context.Context) error {
		replicate(connBackend, connPub, errBackend)
		return nil
	})

	var message string
	select {
	case err = <-errClient:
		message = "websocketproxy: Error when copying from backend to client: %v"
	case err = <-errBackend:
		message = "websocketproxy: Error when copying from client to backend: %v"
	}
	// Orderly close frames are expected; only log transport failures and
	// abnormal closures.
	if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
		logrus.Errorf(message, err)
	}
}
// UpdateEndpoints replaces the proxy's backend set with the endpoints built
// from the given strings by CreateEndpoints.
//
// NOTE(review): h.endpoints is read by Proxy without visible synchronization;
// confirm callers never update endpoints while connections are being proxied.
func (h *WebSocketProxy) UpdateEndpoints(endpoints ...string) {
	updated := CreateEndpoints(endpoints)
	h.endpoints = updated
}
// Do is not supported by WebSocketProxy; connections are relayed only through
// Proxy. It always returns a nil response together with an error.
func (h *WebSocketProxy) Do(r *http.Request) (*http.Response, error) {
	err := fmt.Errorf("do not support")
	return nil, err
}
// createWebSocketProxy builds a WebSocketProxy over the given endpoints.
// The "dockerlog" proxy uses the balancer from NewSelectBalance; every other
// name uses round-robin balancing.
func createWebSocketProxy(name string, endpoints []string) *WebSocketProxy {
	backends := CreateEndpoints(endpoints)

	var balancer LoadBalance
	switch name {
	case "dockerlog":
		balancer = NewSelectBalance()
	default:
		balancer = NewRoundRobin()
	}

	return &WebSocketProxy{
		name:      name,
		endpoints: backends,
		lb:        balancer,
	}
}