package balancer
import (
"fmt"
"sync"
"time"
"github.com/bestruirui/octopus/internal/model"
"github.com/bestruirui/octopus/internal/op"
"github.com/bestruirui/octopus/internal/utils/log"
)
// CircuitState is the position of a single breaker in its
// Closed -> Open -> HalfOpen cycle.
type CircuitState int

const (
	// StateClosed lets requests through while counting consecutive failures.
	StateClosed CircuitState = 0
	// StateOpen blocks requests until the cooldown has elapsed.
	StateOpen CircuitState = 1
	// StateHalfOpen is entered after the cooldown; the outcome of the next
	// recorded request decides whether the breaker closes or reopens.
	StateHalfOpen CircuitState = 2
)
// circuitEntry holds the mutable breaker state for one
// channel/key/model combination. In this file every read and write of the
// fields below happens with mu held.
type circuitEntry struct {
	// State is the breaker's current CircuitState.
	State CircuitState
	// ConsecutiveFailures counts back-to-back failures recorded while Closed;
	// it is compared against the configured threshold.
	ConsecutiveFailures int64
	// LastFailureTime is when a failure was last recorded; IsTripped measures
	// the Open cooldown from it.
	LastFailureTime time.Time
	// TripCount is how many times the breaker has opened since a success last
	// reset it; it selects the exponential cooldown.
	TripCount int

	mu sync.Mutex // guards every field above
}
var (
	// globalBreaker maps circuitKey strings to *circuitEntry values. Entries
	// are created lazily and are never removed in this file.
	globalBreaker sync.Map
)
// circuitKey builds the globalBreaker key "channelID:keyID:modelName".
// Every non-string operand below sits next to a string operand, so
// fmt.Sprint inserts no separating spaces.
func circuitKey(channelID, keyID int, modelName string) string {
	const sep = ":"
	return fmt.Sprint(channelID, sep, keyID, sep, modelName)
}
// getOrCreateEntry returns the entry stored under key, inserting a fresh
// Closed entry when none exists. The plain Load runs first so the common
// case allocates nothing. LoadOrStore guarantees that concurrent callers
// racing on a new key all end up sharing the same entry.
func getOrCreateEntry(key string) *circuitEntry {
	stored, found := globalBreaker.Load(key)
	if !found {
		stored, _ = globalBreaker.LoadOrStore(key, &circuitEntry{State: StateClosed})
	}
	return stored.(*circuitEntry)
}
// getThreshold returns the number of consecutive failures that opens a
// breaker. It reads the SettingKeyCircuitBreakerThreshold setting and falls
// back to 5 when the setting cannot be read or is not positive.
func getThreshold() int64 {
	const defaultThreshold = 5
	if configured, err := op.SettingGetInt(model.SettingKeyCircuitBreakerThreshold); err == nil && configured > 0 {
		return int64(configured)
	}
	return defaultThreshold
}
// GetCooldown returns how long a breaker that has opened tripCount times
// stays Open before IsTripped moves it to HalfOpen.
//
// The base cooldown (setting SettingKeyCircuitBreakerCooldown, default 60s)
// doubles once for every trip after the first, with at most 20 doublings.
// The result is capped at SettingKeyCircuitBreakerMaxCooldown (default
// 600s). Unreadable or non-positive settings fall back to their defaults,
// and tripCount <= 1 yields the (capped) base.
func GetCooldown(tripCount int) time.Duration {
	base, err := op.SettingGetInt(model.SettingKeyCircuitBreakerCooldown)
	if err != nil || base <= 0 {
		base = 60
	}
	maxCooldown, err := op.SettingGetInt(model.SettingKeyCircuitBreakerMaxCooldown)
	if err != nil || maxCooldown <= 0 {
		maxCooldown = 600
	}

	shift := tripCount - 1
	if shift > 20 {
		shift = 20
	}
	// Double step by step instead of computing base << shift. The shift can
	// overflow for a large base (3600<<20 already exceeds a 32-bit int), and
	// a wrapped, possibly negative value would slip under the cap below.
	cooldown := base
	for i := 0; i < shift; i++ {
		if cooldown > maxCooldown/2 {
			// The next doubling would pass the cap anyway; stop before it can overflow.
			cooldown = maxCooldown
			break
		}
		cooldown *= 2
	}
	if cooldown > maxCooldown {
		cooldown = maxCooldown
	}

	// time.Duration counts nanoseconds in an int64; clamp so an enormous
	// configured cap cannot overflow into a negative duration.
	const maxSeconds = int64((1<<63 - 1) / time.Second)
	if int64(cooldown) > maxSeconds {
		return time.Duration(maxSeconds) * time.Second
	}
	return time.Duration(cooldown) * time.Second
}
// IsTripped reports whether requests for (channelID, keyID, modelName)
// should currently be blocked.
//
// Return values by state:
//   - No entry (never failed), or Closed: (false, 0).
//   - Open with cooldown remaining: (true, time left until the cooldown ends).
//   - Open with cooldown elapsed: the breaker moves to HalfOpen and this one
//     caller gets (false, 0).
//   - HalfOpen: (true, 0) for every later caller until RecordSuccess or
//     RecordFailure changes the state.
//
// NOTE(review): nothing in this file leaves HalfOpen except
// RecordSuccess/RecordFailure. If the request let through above never
// reports an outcome, the key looks like it stays blocked indefinitely.
// Confirm that callers always record one.
func IsTripped(channelID, keyID int, modelName string) (tripped bool, remaining time.Duration) {
	key := circuitKey(channelID, keyID, modelName)
	raw, exists := globalBreaker.Load(key)
	if !exists {
		return false, 0
	}
	e := raw.(*circuitEntry)
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.State == StateHalfOpen {
		return true, 0
	}
	if e.State != StateOpen {
		// Closed, or any unrecognised state value, lets the request through.
		return false, 0
	}

	cooldown := GetCooldown(e.TripCount)
	elapsed := time.Since(e.LastFailureTime)
	if elapsed < cooldown {
		return true, cooldown - elapsed
	}
	e.State = StateHalfOpen
	log.Infof("circuit breaker [%s] Open -> HalfOpen (cooldown %v elapsed)", key, cooldown)
	return false, 0
}
// RecordSuccess resets the breaker for (channelID, keyID, modelName) to
// Closed from whatever state it is in, clearing both the consecutive-failure
// count and the trip count. A key that has never failed has no entry and is
// left untouched. The HalfOpen -> Closed transition is logged.
func RecordSuccess(channelID, keyID int, modelName string) {
	key := circuitKey(channelID, keyID, modelName)
	raw, exists := globalBreaker.Load(key)
	if !exists {
		return
	}
	e := raw.(*circuitEntry)
	e.mu.Lock()
	defer e.mu.Unlock()

	if wasProbing := e.State == StateHalfOpen; wasProbing {
		log.Infof("circuit breaker [%s] HalfOpen -> Closed (probe succeeded)", key)
	}
	e.State, e.ConsecutiveFailures, e.TripCount = StateClosed, 0, 0
}
// RecordFailure registers a failed request for (channelID, keyID,
// modelName), creating the breaker entry on first use.
//
// Behavior by state:
//   - Closed: stamps LastFailureTime and counts the failure. Reaching
//     getThreshold() consecutive failures opens the breaker and increments
//     TripCount.
//   - HalfOpen: the request let through after the cooldown failed. The
//     breaker reopens with an incremented TripCount, which lengthens the next
//     cooldown, and a fresh LastFailureTime.
//   - Open: ignored. LastFailureTime is deliberately NOT refreshed here.
//     Presumably only requests already in flight when the breaker tripped can
//     fail in this state (assuming callers check IsTripped before
//     dispatching). Letting them restart the cooldown clock would keep
//     pushing the half-open probe past the cooldown that was logged on trip.
func RecordFailure(channelID, keyID int, modelName string) {
	key := circuitKey(channelID, keyID, modelName)
	entry := getOrCreateEntry(key)
	entry.mu.Lock()
	defer entry.mu.Unlock()

	switch entry.State {
	case StateClosed:
		entry.LastFailureTime = time.Now()
		entry.ConsecutiveFailures++
		threshold := getThreshold()
		if entry.ConsecutiveFailures >= threshold {
			entry.State = StateOpen
			entry.TripCount++
			log.Warnf("circuit breaker [%s] Closed -> Open (failures=%d >= threshold=%d, tripCount=%d, cooldown=%v)",
				key, entry.ConsecutiveFailures, threshold, entry.TripCount, GetCooldown(entry.TripCount))
		}
	case StateHalfOpen:
		entry.LastFailureTime = time.Now()
		entry.State = StateOpen
		entry.TripCount++
		entry.ConsecutiveFailures = 0
		log.Warnf("circuit breaker [%s] HalfOpen -> Open (probe failed, tripCount=%d, cooldown=%v)",
			key, entry.TripCount, GetCooldown(entry.TripCount))
	case StateOpen:
		// Deliberately a no-op: see the function comment.
	}
}