package relay
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
"slices"
"strings"
"time"
"github.com/bestruirui/octopus/internal/helper"
dbmodel "github.com/bestruirui/octopus/internal/model"
"github.com/bestruirui/octopus/internal/op"
"github.com/bestruirui/octopus/internal/relay/balancer"
"github.com/bestruirui/octopus/internal/server/resp"
"github.com/bestruirui/octopus/internal/transformer/inbound"
"github.com/bestruirui/octopus/internal/transformer/model"
"github.com/bestruirui/octopus/internal/transformer/outbound"
"github.com/bestruirui/octopus/internal/utils/log"
"github.com/gin-gonic/gin"
"github.com/tmaxmax/go-sse"
)
// Handler relays one inbound LLM request to an upstream channel.
//
// The request is parsed with the adapter for inboundType, checked against the
// comma-separated "supported_models" allow-list stored on the gin context
// (when non-empty), and then offered to the candidates produced by the
// balancer iterator for the requested model. Iteration stops when an attempt
// succeeds, when a failed attempt has already written to the client, when the
// request context is canceled, or when the candidates are exhausted (502).
func Handler(inboundType inbound.InboundType, c *gin.Context) {
	llmReq, inAdapter, err := parseRequest(inboundType, c)
	if err != nil {
		// parseRequest has already written the error response.
		return
	}

	if allowList := c.GetString("supported_models"); allowList != "" {
		if !slices.Contains(strings.Split(allowList, ","), llmReq.Model) {
			resp.Error(c, http.StatusBadRequest, "model not supported")
			return
		}
	}

	// Remember the model the client asked for: llmReq.Model is overwritten
	// with the channel-specific model name on every attempt below.
	requestModel := llmReq.Model
	apiKeyID := c.GetInt("api_key_id")
	ctx := c.Request.Context()

	group, err := op.GroupGetEnabledMap(requestModel, ctx)
	if err != nil {
		resp.Error(c, http.StatusNotFound, "model not found")
		return
	}

	iter := balancer.NewIterator(group, apiKeyID, requestModel)
	if iter.Len() == 0 {
		resp.Error(c, http.StatusServiceUnavailable, "no available channel")
		return
	}

	metrics := NewRelayMetrics(apiKeyID, requestModel, llmReq)
	req := &relayRequest{
		c:               c,
		inAdapter:       inAdapter,
		internalRequest: llmReq,
		metrics:         metrics,
		apiKeyID:        apiKeyID,
		requestModel:    requestModel,
		iter:            iter,
	}

	var lastErr error
	for iter.Next() {
		// Stop retrying as soon as the request context is done.
		if ctx.Err() != nil {
			log.Infof("request context canceled, stopping retry")
			metrics.Save(ctx, false, context.Canceled, iter.Attempts())
			return
		}

		item := iter.Item()
		ch, err := op.ChannelGet(item.ChannelID, ctx)
		if err != nil {
			log.Warnf("failed to get channel %d: %v", item.ChannelID, err)
			iter.Skip(item.ChannelID, 0, fmt.Sprintf("channel_%d", item.ChannelID), fmt.Sprintf("channel not found: %v", err))
			lastErr = err
			continue
		}
		if !ch.Enabled {
			iter.Skip(ch.ID, 0, ch.Name, "channel disabled")
			continue
		}

		usedKey := ch.GetChannelKey()
		if usedKey.ChannelKey == "" {
			iter.Skip(ch.ID, 0, ch.Name, "no available key")
			continue
		}
		if iter.SkipCircuitBreak(ch.ID, usedKey.ID, ch.Name) {
			continue
		}

		outAdapter := outbound.Get(ch.Type)
		if outAdapter == nil {
			iter.Skip(ch.ID, usedKey.ID, ch.Name, fmt.Sprintf("unsupported channel type: %d", ch.Type))
			continue
		}

		// The channel type must match the kind of request being relayed.
		switch {
		case llmReq.IsEmbeddingRequest() && !outbound.IsEmbeddingChannelType(ch.Type):
			iter.Skip(ch.ID, usedKey.ID, ch.Name, "channel type not compatible with embedding request")
			continue
		case llmReq.IsChatRequest() && !outbound.IsChatChannelType(ch.Type):
			iter.Skip(ch.ID, usedKey.ID, ch.Name, "channel type not compatible with chat request")
			continue
		}

		llmReq.Model = item.ModelName
		log.Infof("request model %s, mode: %d, forwarding to channel: %s model: %s (attempt %d/%d, sticky=%t)",
			requestModel, group.Mode, ch.Name, item.ModelName,
			iter.Index()+1, iter.Len(), iter.IsSticky())

		ra := &relayAttempt{
			relayRequest:         req,
			outAdapter:           outAdapter,
			channel:              ch,
			usedKey:              usedKey,
			firstTokenTimeOutSec: group.FirstTokenTimeOut,
		}
		result := ra.attempt()
		switch {
		case result.Success:
			metrics.Save(ctx, true, nil, iter.Attempts())
			return
		case result.Written:
			// Part of a response already reached the client, so another
			// channel cannot be tried for this request.
			metrics.Save(ctx, false, result.Err, iter.Attempts())
			return
		}
		lastErr = result.Err
	}

	metrics.Save(ctx, false, lastErr, iter.Attempts())
	resp.Error(c, http.StatusBadGateway, "all channels failed")
}
// attempt performs a single forwarding attempt on ra.channel using ra.usedKey
// and records the outcome.
//
// In both outcomes the key is stamped with the status code returned by forward
// and the current Unix time and then persisted, the iterator's attempt span is
// ended, the channel statistics are updated, the balancer is told about the
// result, and the channel's param override is copied into the metrics.
//
// On success the collected response is stored in the metrics, its input and
// output cost is added to the key's total cost, and the channel/key pair is
// registered as sticky for this API key and requested model. On failure the
// result reports whether anything was already written to the client; Err wraps
// the forwarding error so callers can inspect it with errors.Is / errors.As.
func (ra *relayAttempt) attempt() attemptResult {
	span := ra.iter.StartAttempt(ra.channel.ID, ra.usedKey.ID, ra.channel.Name)

	statusCode, fwdErr := ra.forward()

	ra.usedKey.StatusCode = statusCode
	ra.usedKey.LastUseTimeStamp = time.Now().Unix()

	if fwdErr == nil {
		// The cost is only known after the response has been collected, so
		// the key is persisted after collectResponse on this path.
		ra.collectResponse()
		ra.usedKey.TotalCost += ra.metrics.Stats.InputCost + ra.metrics.Stats.OutputCost
		op.ChannelKeyUpdate(ra.usedKey)

		span.End(dbmodel.AttemptSuccess, statusCode, "")
		op.StatsChannelUpdate(ra.channel.ID, dbmodel.StatsMetrics{
			WaitTime:       span.Duration().Milliseconds(),
			RequestSuccess: 1,
		})
		balancer.RecordSuccess(ra.channel.ID, ra.usedKey.ID, ra.internalRequest.Model)
		balancer.SetSticky(ra.apiKeyID, ra.requestModel, ra.channel.ID, ra.usedKey.ID)
		ra.metrics.ParamOverride = paramOverrideValue(ra.channel.ParamOverride)
		return attemptResult{Success: true}
	}

	op.ChannelKeyUpdate(ra.usedKey)
	span.End(dbmodel.AttemptFailed, statusCode, fwdErr.Error())
	op.StatsChannelUpdate(ra.channel.ID, dbmodel.StatsMetrics{
		WaitTime:      span.Duration().Milliseconds(),
		RequestFailed: 1,
	})
	balancer.RecordFailure(ra.channel.ID, ra.usedKey.ID, ra.internalRequest.Model)
	ra.metrics.ParamOverride = paramOverrideValue(ra.channel.ParamOverride)

	// If bytes already went out, keep whatever partial response was collected
	// so it still shows up in the metrics.
	written := ra.c.Writer.Written()
	if written {
		ra.collectResponse()
	}

	return attemptResult{
		Success: false,
		Written: written,
		// %w (instead of %v) keeps the cause reachable through the chain;
		// the formatted message is unchanged.
		Err: fmt.Errorf("channel %s failed: %w", ra.channel.Name, fwdErr),
	}
}
// parseRequest reads the whole request body, converts it to the internal
// request form with the inbound adapter registered for inboundType, attaches
// the URL query parameters and validates the result.
//
// On any failure the error response has already been written to c (500 for
// read and transform failures, 400 for validation failures) and the error is
// returned so the caller only needs to stop.
func parseRequest(inboundType inbound.InboundType, c *gin.Context) (*model.InternalLLMRequest, model.Inbound, error) {
	rawBody, readErr := io.ReadAll(c.Request.Body)
	if readErr != nil {
		resp.Error(c, http.StatusInternalServerError, readErr.Error())
		return nil, nil, readErr
	}

	adapter := inbound.Get(inboundType)
	parsed, transformErr := adapter.TransformRequest(c.Request.Context(), rawBody)
	if transformErr != nil {
		resp.Error(c, http.StatusInternalServerError, transformErr.Error())
		return nil, nil, transformErr
	}

	// Carry the original query string along with the parsed request.
	parsed.Query = c.Request.URL.Query()

	validateErr := parsed.Validate()
	if validateErr != nil {
		resp.Error(c, http.StatusBadRequest, validateErr.Error())
		return nil, nil, validateErr
	}

	return parsed, adapter, nil
}
// forward builds the outbound request for the attempt's channel, applies the
// channel's param override (if any), sends the request and relays the upstream
// response to the client.
//
// It returns the upstream HTTP status code and a nil error when the response
// was relayed. When the upstream answers with a non-2xx status, that status is
// returned together with an error carrying the upstream body. For every other
// failure (request construction, transport error, response transformation)
// the status code is 0.
//
// A param override that cannot be applied (request body or override is not a
// JSON object, or re-encoding fails) is skipped with a warning and the request
// is sent with its original body.
func (ra *relayAttempt) forward() (int, error) {
	ctx := ra.c.Request.Context()

	outboundRequest, err := ra.outAdapter.TransformRequest(
		ctx,
		ra.internalRequest,
		ra.channel.GetBaseUrl(),
		ra.usedKey.ChannelKey,
	)
	if err != nil {
		log.Warnf("failed to create request: %v", err)
		return 0, fmt.Errorf("failed to create request: %w", err)
	}

	if ra.channel.ParamOverride != nil && *ra.channel.ParamOverride != "" && outboundRequest.Body != nil {
		body, err := io.ReadAll(outboundRequest.Body)
		if err != nil {
			return 0, fmt.Errorf("failed to read body: %w", err)
		}
		// The original body is fully consumed; a close error carries no
		// useful information here.
		_ = outboundRequest.Body.Close()

		// finalBody stays the untouched original unless every override step
		// succeeds. A skipped override must still fall through to sending
		// the request rather than returning early.
		finalBody := body
		var bodyMap, override map[string]any
		if err := json.Unmarshal(body, &bodyMap); err != nil {
			log.Warnf("failed to unmarshal request body: %v, skipping param_override", err)
		} else if err := json.Unmarshal([]byte(*ra.channel.ParamOverride), &override); err != nil {
			log.Warnf("failed to unmarshal param_override: %v, skipping", err)
		} else {
			if bodyMap == nil {
				// A JSON `null` body decodes to a nil map; writing into it
				// would panic.
				bodyMap = make(map[string]any, len(override))
			}
			maps.Copy(bodyMap, override)
			if modifiedBody, err := json.Marshal(bodyMap); err != nil {
				log.Warnf("failed to marshal modified body: %v, skipping param_override", err)
			} else {
				finalBody = modifiedBody
			}
		}

		outboundRequest.Body = io.NopCloser(bytes.NewReader(finalBody))
		outboundRequest.ContentLength = int64(len(finalBody))
		// Keep GetBody in sync with Body so that a replay by the HTTP client
		// (redirect or retry) sends the same bytes that ContentLength
		// describes.
		outboundRequest.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(finalBody)), nil
		}
	}

	ra.copyHeaders(outboundRequest)

	response, err := ra.sendRequest(outboundRequest)
	if err != nil {
		return 0, fmt.Errorf("failed to send request: %w", err)
	}
	defer response.Body.Close()

	if response.StatusCode < 200 || response.StatusCode >= 300 {
		// Report the real upstream status so the caller can record it on the
		// key and the attempt span.
		body, err := io.ReadAll(response.Body)
		if err != nil {
			return response.StatusCode, fmt.Errorf("failed to read response body: %w", err)
		}
		return response.StatusCode, fmt.Errorf("upstream error: %d: %s", response.StatusCode, string(body))
	}

	if ra.internalRequest.Stream != nil && *ra.internalRequest.Stream {
		if err := ra.handleStreamResponse(ctx, response); err != nil {
			return 0, err
		}
		return response.StatusCode, nil
	}

	if err := ra.handleResponse(ctx, response); err != nil {
		return 0, err
	}
	return response.StatusCode, nil
}
// copyHeaders copies the client's request headers onto outboundRequest and
// then applies the channel's custom headers.
//
// Headers whose lower-cased name is listed in hopByHopHeaders are not copied.
// A copied header replaces any value already present on outboundRequest under
// the same name, and every value of a multi-valued client header is kept.
// Custom headers are applied last, so they override both.
func (ra *relayAttempt) copyHeaders(outboundRequest *http.Request) {
	for key, values := range ra.c.Request.Header {
		if hopByHopHeaders[strings.ToLower(key)] {
			continue
		}
		if len(values) == 0 {
			// Nothing to copy; leave whatever is already set untouched.
			continue
		}
		// Del followed by Add replaces the existing value while preserving
		// all client values; Set inside the loop would keep only the last.
		outboundRequest.Header.Del(key)
		for _, value := range values {
			outboundRequest.Header.Add(key, value)
		}
	}

	for _, header := range ra.channel.CustomHeader {
		outboundRequest.Header.Set(header.HeaderKey, header.HeaderValue)
	}
}
// sendRequest executes req with the HTTP client that
// helper.ChannelHttpClient returns for the attempt's channel.
//
// Both failure modes (no client available, request failed) are logged as
// warnings and returned unchanged. On success the caller owns the response and
// must close its body.
func (ra *relayAttempt) sendRequest(req *http.Request) (*http.Response, error) {
	client, clientErr := helper.ChannelHttpClient(ra.channel)
	if clientErr != nil {
		log.Warnf("failed to get http client: %v", clientErr)
		return nil, clientErr
	}

	upstream, doErr := client.Do(req)
	if doErr != nil {
		log.Warnf("failed to send request: %v", doErr)
		return nil, doErr
	}

	return upstream, nil
}
// handleStreamResponse relays an upstream SSE response to the client event by
// event.
//
// An upstream response that declares a Content-Type other than
// text/event-stream is rejected before anything is written, with up to 16 KiB
// of its body included in the error. Otherwise each event is converted by
// transformStreamData and written and flushed immediately; events that fail to
// convert or convert to nothing are dropped.
//
// When ra.firstTokenTimeOutSec is positive, an error is returned if no
// writable data has been produced within that many seconds, which lets the
// caller try another channel. The function returns nil when the stream ends or
// when ctx is done, and an error when reading an event fails.
func (ra *relayAttempt) handleStreamResponse(ctx context.Context, response *http.Response) error {
	if ct := response.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "text/event-stream") {
		// Best-effort excerpt for the error message; a read failure only
		// shortens it.
		body, _ := io.ReadAll(io.LimitReader(response.Body, 16*1024))
		return fmt.Errorf("upstream returned non-SSE content-type %q for stream request: %s", ct, string(body))
	}

	ra.c.Header("Content-Type", "text/event-stream")
	ra.c.Header("Cache-Control", "no-cache")
	ra.c.Header("Connection", "keep-alive")
	ra.c.Header("X-Accel-Buffering", "no")

	type sseReadResult struct {
		data string
		err  error
	}
	results := make(chan sseReadResult, 1)

	// done is closed when this function returns. Without it the reader
	// goroutine could block forever on a send to results once nobody is
	// receiving any more (client gone, first-token timeout, read error).
	done := make(chan struct{})
	defer close(done)

	go func() {
		defer close(results)
		// deliver reports false when the consumer has already returned.
		deliver := func(r sseReadResult) bool {
			select {
			case results <- r:
				return true
			case <-done:
				return false
			}
		}
		readCfg := &sse.ReadConfig{MaxEventSize: maxSSEEventSize}
		for ev, err := range sse.Read(response.Body, readCfg) {
			if err != nil {
				deliver(sseReadResult{err: err})
				return
			}
			if !deliver(sseReadResult{data: ev.Data}) {
				return
			}
		}
	}()

	firstToken := true
	var firstTokenTimer *time.Timer
	// firstTokenC stays nil when no timeout is configured; receiving from a
	// nil channel blocks forever, which disables that select case.
	var firstTokenC <-chan time.Time
	if ra.firstTokenTimeOutSec > 0 {
		firstTokenTimer = time.NewTimer(time.Duration(ra.firstTokenTimeOutSec) * time.Second)
		firstTokenC = firstTokenTimer.C
		defer func() {
			if firstTokenTimer != nil {
				firstTokenTimer.Stop()
			}
		}()
	}

	for {
		select {
		case <-ctx.Done():
			log.Infof("client disconnected, stopping stream")
			return nil
		case <-firstTokenC:
			log.Warnf("first token timeout (%ds), switching channel", ra.firstTokenTimeOutSec)
			// Closing the body also unblocks the reader goroutine.
			_ = response.Body.Close()
			return fmt.Errorf("first token timeout (%ds)", ra.firstTokenTimeOutSec)
		case r, ok := <-results:
			if !ok {
				log.Infof("stream end")
				return nil
			}
			if r.err != nil {
				log.Warnf("failed to read event: %v", r.err)
				return fmt.Errorf("failed to read stream event: %w", r.err)
			}

			data, err := ra.transformStreamData(ctx, r.data)
			if err != nil || len(data) == 0 {
				// transformStreamData already logged the failure; one bad or
				// empty event does not end the stream.
				continue
			}

			if firstToken {
				ra.metrics.SetFirstTokenTime(time.Now())
				firstToken = false
				if firstTokenTimer != nil {
					// Disarm the timeout for the rest of the stream.
					if !firstTokenTimer.Stop() {
						select {
						case <-firstTokenTimer.C:
						default:
						}
					}
					firstTokenTimer = nil
					firstTokenC = nil
				}
			}

			// A write error is deliberately not treated as a stream failure:
			// the loop ends through the ctx.Done branch once the request
			// context is canceled.
			_, _ = ra.c.Writer.Write(data)
			ra.c.Writer.Flush()
		}
	}
}
// transformStreamData converts one upstream stream payload into the bytes to
// send to the client: first through the outbound adapter into the internal
// stream form, then through the inbound adapter into the client's format.
//
// It returns (nil, nil) when the outbound adapter produces nothing for this
// payload. Adapter errors are logged as warnings and returned.
func (ra *relayAttempt) transformStreamData(ctx context.Context, data string) ([]byte, error) {
	chunk, outErr := ra.outAdapter.TransformStream(ctx, []byte(data))
	switch {
	case outErr != nil:
		log.Warnf("failed to transform stream: %v", outErr)
		return nil, outErr
	case chunk == nil:
		return nil, nil
	}

	payload, inErr := ra.inAdapter.TransformStream(ctx, chunk)
	if inErr != nil {
		log.Warnf("failed to transform stream: %v", inErr)
		return nil, inErr
	}
	return payload, nil
}
// handleResponse relays a non-streaming upstream response: it is converted to
// the internal form by the outbound adapter, then to the client's format by
// the inbound adapter, and written to the client as a 200 application/json
// body. Nothing is written when either conversion fails.
func (ra *relayAttempt) handleResponse(ctx context.Context, response *http.Response) error {
	internal, outErr := ra.outAdapter.TransformResponse(ctx, response)
	if outErr != nil {
		log.Warnf("failed to transform response: %v", outErr)
		return fmt.Errorf("failed to transform outbound response: %w", outErr)
	}

	payload, inErr := ra.inAdapter.TransformResponse(ctx, internal)
	if inErr != nil {
		log.Warnf("failed to transform response: %v", inErr)
		return fmt.Errorf("failed to transform inbound response: %w", inErr)
	}

	ra.c.Data(http.StatusOK, "application/json", payload)
	return nil
}
// collectResponse asks the inbound adapter for the internal response it has
// gathered and, when one is available, stores it in the metrics together with
// the model name currently set on the internal request. It does nothing when
// the adapter returns an error or no response.
func (ra *relayAttempt) collectResponse() {
	collected, err := ra.inAdapter.GetInternalResponse(ra.c.Request.Context())
	if err == nil && collected != nil {
		ra.metrics.SetInternalResponse(collected, ra.internalRequest.Model)
	}
}
// paramOverrideValue returns the string ptr points to, or "" when ptr is nil.
func paramOverrideValue(ptr *string) string {
	if ptr != nil {
		return *ptr
	}
	return ""
}