package helper
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"github.com/bestruirui/octopus/internal/model"
	"github.com/bestruirui/octopus/internal/transformer/outbound"
	"github.com/dlclark/regexp2"
)
// FetchModels lists the model IDs exposed by the upstream described by
// request. It picks the listing routine by request.Type: Anthropic, Gemini,
// or the OpenAI-compatible listing for every other type.
//
// When request.MatchRegex is non-nil and non-empty, it is compiled with
// regexp2's ECMAScript option and only the IDs it matches are returned.
// Errors from building the HTTP client, fetching the list, or compiling or
// applying the pattern are returned to the caller.
func FetchModels(ctx context.Context, request model.Channel) ([]string, error) {
	client, err := ChannelHttpClient(&request)
	if err != nil {
		return nil, err
	}

	var fetched []string
	switch request.Type {
	case outbound.OutboundTypeAnthropic:
		fetched, err = fetchAnthropicModels(client, ctx, request)
	case outbound.OutboundTypeGemini:
		fetched, err = fetchGeminiModels(client, ctx, request)
	default:
		fetched, err = fetchOpenAIModels(client, ctx, request)
	}
	if err != nil {
		return nil, err
	}

	if request.MatchRegex == nil || *request.MatchRegex == "" {
		return fetched, nil
	}

	re, err := regexp2.Compile(*request.MatchRegex, regexp2.ECMAScript)
	if err != nil {
		return nil, fmt.Errorf("compile match regex %q: %w", *request.MatchRegex, err)
	}
	matched := make([]string, 0, len(fetched))
	// The loop variable is named id, not model, so it does not shadow the
	// imported model package.
	for _, id := range fetched {
		ok, err := re.MatchString(id)
		if err != nil {
			return nil, fmt.Errorf("match model %q against regex: %w", id, err)
		}
		if ok {
			matched = append(matched, id)
		}
	}
	return matched, nil
}
// fetchOpenAIModels lists models through the OpenAI-compatible
// GET {baseURL}/models endpoint. It authenticates with a Bearer token built
// from the channel key and applies every custom header whose key is non-empty.
// It returns the ID of each entry in the decoded model.OpenAIModelList.
//
// A non-2xx response is reported as an error instead of being decoded. An
// error body (for example a 401 for a bad key) would otherwise decode into an
// empty list and look like "no models".
func fetchOpenAIModels(client *http.Client, ctx context.Context, request model.Channel) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, request.GetBaseUrl()+"/models", nil)
	if err != nil {
		return nil, fmt.Errorf("build openai models request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+request.GetChannelKey().ChannelKey)
	for _, header := range request.CustomHeader {
		if header.HeaderKey != "" {
			req.Header.Set(header.HeaderKey, header.HeaderValue)
		}
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("list openai models: unexpected status %s", resp.Status)
	}

	var result model.OpenAIModelList
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}
	models := make([]string, 0, len(result.Data))
	for _, m := range result.Data {
		models = append(models, m.ID)
	}
	return models, nil
}
// fetchGeminiModels lists models through the Gemini-native GET
// {baseURL}/models endpoint. It authenticates with the X-Goog-Api-Key header,
// follows NextPageToken across pages, and strips the "models/" prefix from
// each returned name.
//
// Fallback and error behaviour:
//   - If the native endpoint yields no models, the OpenAI-compatible listing
//     is tried instead. This covers an empty list, and a first page rejected
//     with a non-2xx status.
//   - A non-2xx status on a later page is returned as an error, so the list
//     is never silently truncated.
func fetchGeminiModels(client *http.Client, ctx context.Context, request model.Channel) ([]string, error) {
	var allModels []string
	pageToken := ""
	for {
		page, status, err := fetchGeminiModelPage(client, ctx, request, pageToken)
		if err != nil {
			return nil, err
		}
		if status < 200 || status > 299 {
			if len(allModels) == 0 {
				// The native endpoint rejected the request outright; let the
				// OpenAI-compatible fallback below try instead.
				break
			}
			return nil, fmt.Errorf("list gemini models: unexpected status %d", status)
		}
		for _, m := range page.Models {
			allModels = append(allModels, strings.TrimPrefix(m.Name, "models/"))
		}
		// Stop when no token is returned. Also stop when the same token comes
		// back, so a misbehaving server cannot make us re-request one page
		// forever.
		if page.NextPageToken == "" || page.NextPageToken == pageToken {
			break
		}
		pageToken = page.NextPageToken
	}
	if len(allModels) == 0 {
		return fetchOpenAIModels(client, ctx, request)
	}
	return allModels, nil
}

// fetchGeminiModelPage requests one page of the Gemini model listing. It sends
// pageToken as the pageToken query parameter when it is non-empty.
//
// It returns the HTTP status code (0 on a transport error). The page is
// decoded only for 2xx responses and is the zero value otherwise, including
// on error. The response body is closed before it returns, so a caller that
// loops does not accumulate open bodies.
func fetchGeminiModelPage(client *http.Client, ctx context.Context, request model.Channel, pageToken string) (model.GeminiModelList, int, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, request.GetBaseUrl()+"/models", nil)
	if err != nil {
		return model.GeminiModelList{}, 0, fmt.Errorf("build gemini models request: %w", err)
	}
	req.Header.Set("X-Goog-Api-Key", request.GetChannelKey().ChannelKey)
	for _, header := range request.CustomHeader {
		if header.HeaderKey != "" {
			req.Header.Set(header.HeaderKey, header.HeaderValue)
		}
	}
	if pageToken != "" {
		q := req.URL.Query()
		q.Add("pageToken", pageToken)
		req.URL.RawQuery = q.Encode()
	}

	resp, err := client.Do(req)
	if err != nil {
		return model.GeminiModelList{}, 0, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return model.GeminiModelList{}, resp.StatusCode, nil
	}
	var page model.GeminiModelList
	if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
		return model.GeminiModelList{}, resp.StatusCode, err
	}
	return page, resp.StatusCode, nil
}
// fetchAnthropicModels lists models through the Anthropic-native GET
// {baseURL}/models endpoint. It authenticates with X-Api-Key, sends
// Anthropic-Version 2023-06-01, and pages with the after_id cursor taken from
// each response's LastID while HasMore is set. It returns the ID of every
// listed model.
//
// Fallback and error behaviour:
//   - If the native endpoint yields no models, the OpenAI-compatible listing
//     is tried instead. This covers an empty list, and a first page rejected
//     with a non-2xx status.
//   - A non-2xx status on a later page is returned as an error, so the list
//     is never silently truncated.
func fetchAnthropicModels(client *http.Client, ctx context.Context, request model.Channel) ([]string, error) {
	var allModels []string
	afterID := ""
	for {
		page, status, err := fetchAnthropicModelPage(client, ctx, request, afterID)
		if err != nil {
			return nil, err
		}
		if status < 200 || status > 299 {
			if len(allModels) == 0 {
				// The native endpoint rejected the request outright; let the
				// OpenAI-compatible fallback below try instead.
				break
			}
			return nil, fmt.Errorf("list anthropic models: unexpected status %d", status)
		}
		for _, m := range page.Data {
			allModels = append(allModels, m.ID)
		}
		// If HasMore is set but no new cursor arrives (empty or unchanged
		// LastID), the next request would fetch the same page forever.
		if !page.HasMore || page.LastID == "" || page.LastID == afterID {
			break
		}
		afterID = page.LastID
	}
	if len(allModels) == 0 {
		return fetchOpenAIModels(client, ctx, request)
	}
	return allModels, nil
}

// fetchAnthropicModelPage requests one page of the Anthropic model listing. It
// sends afterID as the after_id query parameter when it is non-empty.
//
// It returns the HTTP status code (0 on a transport error). The page is
// decoded only for 2xx responses and is the zero value otherwise, including
// on error. The response body is closed before it returns, so a caller that
// loops does not accumulate open bodies.
func fetchAnthropicModelPage(client *http.Client, ctx context.Context, request model.Channel, afterID string) (model.AnthropicModelList, int, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, request.GetBaseUrl()+"/models", nil)
	if err != nil {
		return model.AnthropicModelList{}, 0, fmt.Errorf("build anthropic models request: %w", err)
	}
	req.Header.Set("X-Api-Key", request.GetChannelKey().ChannelKey)
	req.Header.Set("Anthropic-Version", "2023-06-01")
	for _, header := range request.CustomHeader {
		if header.HeaderKey != "" {
			req.Header.Set(header.HeaderKey, header.HeaderValue)
		}
	}
	q := req.URL.Query()
	if afterID != "" {
		q.Set("after_id", afterID)
	}
	req.URL.RawQuery = q.Encode()

	resp, err := client.Do(req)
	if err != nil {
		return model.AnthropicModelList{}, 0, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return model.AnthropicModelList{}, resp.StatusCode, nil
	}
	var page model.AnthropicModelList
	if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
		return model.AnthropicModelList{}, resp.StatusCode, err
	}
	return page, resp.StatusCode, nil
}