package proxy
import (
"net/http"
"strings"
"sync/atomic"
"github.com/sirupsen/logrus"
)
type (
	// ContextKey is the key type this package uses when reading values from a
	// request context (see ContextKey("host_id") in SelectBalance.Select).
	// Because it is a distinct defined type, it does not collide with plain
	// string keys stored by other packages.
	ContextKey string

	// RoundRobin rotates through endpoints using an atomically incremented
	// counter. ops points at that counter, so copies of a RoundRobin value
	// share one rotation. Construct it with NewRoundRobin; a zero value has a
	// nil counter.
	RoundRobin struct {
		ops *uint64
	}

	// LoadBalance chooses one endpoint out of endpoints to serve request r.
	LoadBalance interface {
		Select(r *http.Request, endpoints EndpointList) Endpoint
	}
)
// Endpoint is a backend address. It may carry a name prefix in the form
// "name=>address" (see GetName and GetAddr).
type Endpoint string

// String implements fmt.Stringer. It returns the raw endpoint text
// unchanged, including any "name=>" prefix.
func (e Endpoint) String() string {
	raw := string(e)
	return raw
}
// GetName returns the text before the first "=>" separator. When the
// endpoint has no separator, the whole endpoint string is returned.
func (e Endpoint) GetName() string {
	raw := string(e)
	sep := strings.Index(raw, "=>")
	if sep < 0 {
		return raw
	}
	return raw[:sep]
}
// GetAddr returns the address part of a "name=>address" endpoint. That is
// the text after the first "=>" and before any second "=>". When the
// endpoint has no separator, the whole endpoint string is returned.
func (e Endpoint) GetAddr() string {
	raw := string(e)
	sep := strings.Index(raw, "=>")
	if sep < 0 {
		return raw
	}
	addr := raw[sep+len("=>"):]
	// Only the segment up to a further separator counts as the address.
	if next := strings.Index(addr, "=>"); next >= 0 {
		addr = addr[:next]
	}
	return addr
}
// GetHTTPAddr returns the endpoint's address with a URL scheme applied by
// withScheme. For "name=>address" endpoints only the address part (the
// segment after the first "=>") is used. Otherwise the whole endpoint
// string is used.
func (e Endpoint) GetHTTPAddr() string {
	addr := string(e)
	if parts := strings.Split(addr, "=>"); len(parts) > 1 {
		addr = parts[1]
	}
	return withScheme(addr)
}
// withScheme returns s as an HTTP(S) URL. Strings that already begin with an
// "http://" or "https://" scheme are returned unchanged. Anything else is
// prefixed with "http://".
//
// The full scheme is matched, not just the "http" prefix. Host names that
// themselves start with "http" (for example "httpbin:8080") must still get
// a scheme added.
func withScheme(s string) string {
	if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") {
		return s
	}
	return "http://" + s
}
// EndpointList is an ordered collection of endpoints.
type EndpointList []Endpoint

// Len reports how many endpoints the list holds.
func (e *EndpointList) Len() int {
	list := *e
	return len(list)
}
// Add appends each given string to the list as an Endpoint, preserving
// argument order. Duplicates are not filtered.
func (e *EndpointList) Add(endpoints ...string) {
	list := *e
	for i := range endpoints {
		list = append(list, Endpoint(endpoints[i]))
	}
	*e = list
}
// Delete removes every endpoint whose full text (including any "name=>"
// prefix) equals one of the given strings. The remaining endpoints keep
// their original order. Calling Delete with no arguments leaves the list
// unchanged.
func (e *EndpointList) Delete(endpoints ...string) {
	if len(endpoints) == 0 {
		return
	}
	remove := make(map[string]struct{}, len(endpoints))
	for _, ep := range endpoints {
		remove[ep] = struct{}{}
	}
	// Build a fresh slice rather than filtering in place. Copies of this list
	// (EndpointList is passed by value, e.g. to Select) may share its backing
	// array and must not observe a partially rewritten slice.
	var kept EndpointList
	for _, old := range *e {
		if _, drop := remove[string(old)]; !drop {
			kept = append(kept, old)
		}
	}
	*e = kept
}
// Selec returns the endpoint at index i. Like any slice index expression,
// it panics when i is out of range.
func (e *EndpointList) Selec(i int) Endpoint {
	list := *e
	return list[i]
}
// HaveEndpoint reports whether the list contains an endpoint whose full
// text, including any "name=>" prefix, equals endpoint.
func (e *EndpointList) HaveEndpoint(endpoint string) bool {
	found := false
	for i := 0; i < len(*e) && !found; i++ {
		found = string((*e)[i]) == endpoint
	}
	return found
}
// CreateEndpoints converts raw endpoint strings into an EndpointList,
// preserving order. A nil or empty input yields a nil list.
func CreateEndpoints(endpoints []string) EndpointList {
	var list EndpointList
	for i := range endpoints {
		list = append(list, Endpoint(endpoints[i]))
	}
	return list
}
// NewRoundRobin returns a RoundRobin balancer backed by its own freshly
// allocated counter, which starts at zero.
func NewRoundRobin() LoadBalance {
	return RoundRobin{ops: new(uint64)}
}
// Select returns the next endpoint in rotation and ignores r. The shared
// counter is incremented atomically before it is used, so the first call on
// a fresh balancer picks index 1 (modulo the list length). An empty list
// yields "" without touching the counter.
//
// NOTE: a zero-value RoundRobin has a nil counter and this call would panic;
// construct balancers with NewRoundRobin.
func (rr RoundRobin) Select(r *http.Request, endpoints EndpointList) Endpoint {
	count := uint64(len(endpoints))
	if count == 0 {
		return ""
	}
	next := atomic.AddUint64(rr.ops, 1)
	return endpoints[next%count]
}
// SelectBalance routes requests by host ID; *SelectBalance provides the
// Select method required by LoadBalance. Select reads the "local" entry of
// hostIDMap as its last-resort address.
type SelectBalance struct {
	hostIDMap map[string]string
}

// NewSelectBalance returns a SelectBalance whose map holds a single entry,
// "local" -> "rbd-eventlog:6363".
func NewSelectBalance() *SelectBalance {
	fallback := make(map[string]string, 1)
	fallback["local"] = "rbd-eventlog:6363"
	sb := new(SelectBalance)
	sb.hostIDMap = fallback
	return sb
}
// Select picks a backend for r by host ID.
//
// Every endpoint of the form "name=>addr" maps name to addr. The name
// "local" is pre-mapped to "rbd-eventlog:6363", and an endpoint named
// "local" overrides that. The host ID comes from the "host_id" query
// parameter. When that is empty, it comes from the request context under
// ContextKey("host_id"), if that value is a string.
//
// If the host ID names a known mapping, its address is returned. Otherwise
// the last endpoint in the list is returned. If the list is empty, the
// "local" entry of s.hostIDMap is returned. A nil request, or one without a
// URL, also gets the "local" entry.
func (s *SelectBalance) Select(r *http.Request, endpoints EndpointList) Endpoint {
	if r == nil || r.URL == nil {
		return Endpoint(s.hostIDMap["local"])
	}

	id2ip := map[string]string{"local": "rbd-eventlog:6363"}
	for _, end := range endpoints {
		if kv := strings.Split(string(end), "=>"); len(kv) > 1 {
			id2ip[kv[0]] = kv[1]
		}
	}

	hostID := r.URL.Query().Get("host_id")
	if hostID == "" {
		// Checked assertion: the context value is populated elsewhere, and a
		// value of an unexpected type must be ignored rather than crash the
		// request handler.
		if fromCtx, ok := r.Context().Value(ContextKey("host_id")).(string); ok {
			hostID = fromCtx
		}
	}
	if e, ok := id2ip[hostID]; ok {
		logrus.Infof("[lb selelct] find host %s from name %s success", e, hostID)
		return Endpoint(e)
	}

	if len(endpoints) > 0 {
		last := endpoints[len(endpoints)-1]
		logrus.Infof("default endpoint is %s", last)
		return last
	}
	return Endpoint(s.hostIDMap["local"])
}