package gemini
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"github.com/bestruirui/octopus/internal/transformer/model"
"github.com/bestruirui/octopus/internal/utils/xurl"
"github.com/samber/lo"
)
type MessagesOutbound struct{}
func (o *MessagesOutbound) TransformRequest(ctx context.Context, request *model.InternalLLMRequest, baseUrl, key string) (*http.Request, error) {
geminiReq := convertLLMToGeminiRequest(request)
body, err := json.Marshal(geminiReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal gemini request: %w", err)
}
parsedUrl, err := url.Parse(strings.TrimSuffix(baseUrl, "/"))
if err != nil {
return nil, fmt.Errorf("failed to parse base url: %w", err)
}
isStream := request.Stream != nil && *request.Stream
method := "generateContent"
if isStream {
method = "streamGenerateContent"
}
modelName := request.Model
if !strings.Contains(modelName, "/") {
modelName = "models/" + modelName
}
parsedUrl.Path = fmt.Sprintf("%s/%s:%s", parsedUrl.Path, modelName, method)
q := parsedUrl.Query()
q.Set("key", key)
if isStream {
q.Set("alt", "sse")
}
parsedUrl.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parsedUrl.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
return req, nil
}
func (o *MessagesOutbound) TransformResponse(ctx context.Context, response *http.Response) (*model.InternalLLMResponse, error) {
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if len(body) == 0 {
return nil, fmt.Errorf("response body is empty")
}
var geminiResp model.GeminiGenerateContentResponse
if err := json.Unmarshal(body, &geminiResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal gemini response: %w", err)
}
return convertGeminiToLLMResponse(&geminiResp, false), nil
}
func (o *MessagesOutbound) TransformStream(ctx context.Context, eventData []byte) (*model.InternalLLMResponse, error) {
if bytes.HasPrefix(eventData, []byte("[DONE]")) || len(eventData) == 0 {
return &model.InternalLLMResponse{
Object: "[DONE]",
}, nil
}
var geminiResp model.GeminiGenerateContentResponse
if err := json.Unmarshal(eventData, &geminiResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal gemini stream chunk: %w", err)
}
return convertGeminiToLLMResponse(&geminiResp, true), nil
}
func reasoningToThinkingBudget(effort string) int32 {
switch strings.ToLower(effort) {
case "low":
return 1024
case "medium":
return 4096
case "high":
return 24576
default:
return -1
}
}
func audioTypeToMimeType(format string) string {
switch format {
case "wav":
return "audio/wav"
case "mp3":
return "audio/mp3"
case "aiff":
return "audio/aiff"
case "aac":
return "audio/aac"
case "ogg":
return "audio/ogg"
case "flac":
return "audio/flac"
default:
return "audio/wav"
}
}
func convertLLMToGeminiRequest(request *model.InternalLLMRequest) *model.GeminiGenerateContentRequest {
geminiReq := &model.GeminiGenerateContentRequest{
Contents: []*model.GeminiContent{},
}
var systemInstruction *model.GeminiContent
for _, msg := range request.Messages {
switch msg.Role {
case "system", "developer":
if systemInstruction == nil {
systemInstruction = &model.GeminiContent{
Parts: []*model.GeminiPart{},
}
}
if msg.Content.Content != nil {
systemInstruction.Parts = append(systemInstruction.Parts, &model.GeminiPart{
Text: *msg.Content.Content,
})
}
case "user":
content := &model.GeminiContent{
Role: "user",
Parts: []*model.GeminiPart{},
}
if msg.Content.Content != nil {
content.Parts = append(content.Parts, &model.GeminiPart{
Text: *msg.Content.Content,
})
}
if msg.Content.MultipleContent != nil {
for _, part := range msg.Content.MultipleContent {
switch part.Type {
case "text":
if part.Text != nil {
content.Parts = append(content.Parts, &model.GeminiPart{
Text: *part.Text,
})
}
case "image_url":
dataurl := xurl.ParseDataURL(part.ImageURL.URL)
if dataurl != nil && dataurl.IsBase64 {
content.Parts = append(content.Parts, &model.GeminiPart{
InlineData: &model.GeminiBlob{
MimeType: dataurl.MediaType,
Data: dataurl.Data,
},
})
}
case "input_audio":
if part.Audio != nil {
content.Parts = append(content.Parts, &model.GeminiPart{
InlineData: &model.GeminiBlob{
MimeType: audioTypeToMimeType(part.Audio.Format),
Data: part.Audio.Data,
},
})
}
case "file":
if part.File != nil {
dataurl := xurl.ParseDataURL(part.File.FileData)
if dataurl != nil && dataurl.IsBase64 {
content.Parts = append(content.Parts, &model.GeminiPart{
InlineData: &model.GeminiBlob{
MimeType: dataurl.MediaType,
Data: dataurl.Data,
},
})
}
}
}
}
}
geminiReq.Contents = append(geminiReq.Contents, content)
case "assistant":
content := &model.GeminiContent{
Role: "model",
Parts: []*model.GeminiPart{},
}
if msg.Content.Content != nil && *msg.Content.Content != "" {
content.Parts = append(content.Parts, &model.GeminiPart{
Text: *msg.Content.Content,
})
}
if len(msg.ToolCalls) > 0 {
for _, toolCall := range msg.ToolCalls {
var args map[string]interface{}
_ = json.Unmarshal([]byte(toolCall.Function.Arguments), &args)
content.Parts = append(content.Parts, &model.GeminiPart{
FunctionCall: &model.GeminiFunctionCall{
Name: toolCall.Function.Name,
Args: args,
},
ThoughtSignature: "skip_thought_signature_validator",
})
}
}
geminiReq.Contents = append(geminiReq.Contents, content)
case "tool":
content := convertLLMToolResultToGeminiContent(&msg)
geminiReq.Contents = append(geminiReq.Contents, content)
}
}
geminiReq.SystemInstruction = systemInstruction
config := &model.GeminiGenerationConfig{}
hasConfig := false
if request.MaxTokens != nil {
config.MaxOutputTokens = int(*request.MaxTokens)
hasConfig = true
}
if request.Temperature != nil {
config.Temperature = request.Temperature
hasConfig = true
}
if request.TopP != nil {
config.TopP = request.TopP
hasConfig = true
}
if topKStr, ok := request.TransformerMetadata["gemini_top_k"]; ok {
var topK int
fmt.Sscanf(topKStr, "%d", &topK)
config.TopK = &topK
hasConfig = true
}
if request.Stop != nil && request.Stop.MultipleStop != nil {
config.StopSequences = request.Stop.MultipleStop
hasConfig = true
} else if request.Stop != nil && request.Stop.Stop != nil {
config.StopSequences = []string{*request.Stop.Stop}
hasConfig = true
}
if request.ReasoningEffort != "" {
budget := reasoningToThinkingBudget(request.ReasoningEffort)
config.ThinkingConfig = &model.GeminiThinkingConfig{
ThinkingBudget: &budget,
IncludeThoughts: true,
}
hasConfig = true
}
if request.ResponseFormat != nil {
switch request.ResponseFormat.Type {
case "json_object":
config.ResponseMimeType = "application/json"
hasConfig = true
case "json_schema":
config.ResponseMimeType = "application/json"
hasConfig = true
case "text":
config.ResponseMimeType = "text/plain"
hasConfig = true
}
}
if len(request.Modalities) > 0 {
convertedModalities := make([]string, len(request.Modalities))
for i, m := range request.Modalities {
if len(m) > 0 {
convertedModalities[i] = strings.ToUpper(m[:1]) + strings.ToLower(m[1:])
}
}
config.ResponseModalities = convertedModalities
hasConfig = true
}
if hasConfig {
geminiReq.GenerationConfig = config
}
if safetyJSON, ok := request.TransformerMetadata["gemini_safety_settings"]; ok {
var safetySettings []*model.GeminiSafetySetting
if err := json.Unmarshal([]byte(safetyJSON), &safetySettings); err == nil {
geminiReq.SafetySettings = safetySettings
}
}
if len(request.Tools) > 0 {
functionDeclarations := make([]*model.GeminiFunctionDeclaration, 0, len(request.Tools))
for _, tool := range request.Tools {
if tool.Type != "function" {
continue
}
var params map[string]any
if len(tool.Function.Parameters) > 0 {
_ = json.Unmarshal(tool.Function.Parameters, ¶ms)
}
cleanGeminiSchema(params)
functionDeclarations = append(functionDeclarations, &model.GeminiFunctionDeclaration{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: params,
})
}
if len(functionDeclarations) > 0 {
geminiReq.Tools = []*model.GeminiTool{{FunctionDeclarations: functionDeclarations}}
}
}
if request.ToolChoice != nil {
mode := "AUTO"
var allowed []string
if request.ToolChoice.ToolChoice != nil {
switch strings.ToLower(*request.ToolChoice.ToolChoice) {
case "auto":
mode = "AUTO"
case "required":
mode = "ANY"
case "none":
mode = "NONE"
}
} else if request.ToolChoice.NamedToolChoice != nil && request.ToolChoice.NamedToolChoice.Type == "function" {
mode = "ANY"
if request.ToolChoice.NamedToolChoice.Function.Name != "" {
allowed = []string{request.ToolChoice.NamedToolChoice.Function.Name}
}
}
geminiReq.ToolConfig = &model.GeminiToolConfig{
FunctionCallingConfig: &model.GeminiFunctionCallingConfig{
Mode: mode,
AllowedFunctionNames: allowed,
},
}
}
return geminiReq
}
func convertLLMToolResultToGeminiContent(msg *model.Message) *model.GeminiContent {
content := &model.GeminiContent{
Role: "user",
}
var responseData map[string]any
if msg.Content.Content != nil {
_ = json.Unmarshal([]byte(*msg.Content.Content), &responseData)
}
if responseData == nil {
responseData = map[string]any{"result": lo.FromPtrOr(msg.Content.Content, "")}
}
fp := &model.GeminiFunctionResponse{
Name: lo.FromPtrOr(msg.ToolCallID, ""),
Response: responseData,
}
content.Parts = []*model.GeminiPart{
{FunctionResponse: fp},
}
return content
}
func convertGeminiToLLMResponse(geminiResp *model.GeminiGenerateContentResponse, isStream bool) *model.InternalLLMResponse {
resp := &model.InternalLLMResponse{
Choices: []model.Choice{},
}
if isStream {
resp.Object = "chat.completion.chunk"
} else {
resp.Object = "chat.completion"
}
for _, candidate := range geminiResp.Candidates {
choice := model.Choice{
Index: candidate.Index,
}
if candidate.FinishReason != nil {
reason := convertGeminiFinishReason(*candidate.FinishReason)
choice.FinishReason = &reason
}
if candidate.Content != nil {
msg := &model.Message{
Role: "assistant",
}
var textParts []string
var contentParts []model.MessageContentPart
var toolCalls []model.ToolCall
var reasoningContent *string
var hasInlineData bool
for idx, part := range candidate.Content.Parts {
if part.Thought {
if part.Text != "" && reasoningContent == nil {
reasoningContent = &part.Text
}
} else if part.Text != "" {
textParts = append(textParts, part.Text)
text := part.Text
contentParts = append(contentParts, model.MessageContentPart{
Type: "text",
Text: &text,
})
}
if part.InlineData != nil {
hasInlineData = true
dataURL := fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data)
contentParts = append(contentParts, model.MessageContentPart{
Type: "image_url",
ImageURL: &model.ImageURL{
URL: dataURL,
},
})
}
if part.FunctionCall != nil {
argsJSON, _ := json.Marshal(part.FunctionCall.Args)
toolCall := model.ToolCall{
Index: idx,
ID: fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, idx),
Type: "function",
Function: model.FunctionCall{
Name: part.FunctionCall.Name,
Arguments: string(argsJSON),
},
}
toolCalls = append(toolCalls, toolCall)
}
}
if hasInlineData {
msg.Content = model.MessageContent{
MultipleContent: contentParts,
}
} else if len(textParts) > 0 {
text := strings.Join(textParts, "")
msg.Content = model.MessageContent{
Content: &text,
}
}
if reasoningContent != nil {
msg.ReasoningContent = reasoningContent
}
if len(toolCalls) > 0 {
msg.ToolCalls = toolCalls
if choice.FinishReason == nil {
reason := "tool_calls"
choice.FinishReason = &reason
}
}
if isStream {
choice.Delta = msg
} else {
choice.Message = msg
}
}
resp.Choices = append(resp.Choices, choice)
}
if geminiResp.UsageMetadata != nil {
usage := &model.Usage{
PromptTokens: int64(geminiResp.UsageMetadata.PromptTokenCount),
CompletionTokens: int64(geminiResp.UsageMetadata.CandidatesTokenCount),
TotalTokens: int64(geminiResp.UsageMetadata.TotalTokenCount),
}
if geminiResp.UsageMetadata.CachedContentTokenCount > 0 {
if usage.PromptTokensDetails == nil {
usage.PromptTokensDetails = &model.PromptTokensDetails{}
}
usage.PromptTokensDetails.CachedTokens = int64(geminiResp.UsageMetadata.CachedContentTokenCount)
}
if geminiResp.UsageMetadata.ThoughtsTokenCount > 0 {
if usage.CompletionTokensDetails == nil {
usage.CompletionTokensDetails = &model.CompletionTokensDetails{}
}
usage.CompletionTokensDetails.ReasoningTokens = int64(geminiResp.UsageMetadata.ThoughtsTokenCount)
}
resp.Usage = usage
}
return resp
}
func convertGeminiFinishReason(reason string) string {
switch reason {
case "STOP":
return "stop"
case "MAX_TOKENS":
return "length"
case "SAFETY":
return "content_filter"
case "RECITATION":
return "content_filter"
default:
return "stop"
}
}
func cleanGeminiSchema(schema map[string]any) {
if schema == nil {
return
}
t := &geminiSchemaTransformer{
root: schema,
visited: map[uintptr]struct{}{},
}
t.transform(schema)
}
type geminiSchemaTransformer struct {
root map[string]any
visited map[uintptr]struct{}
}
func (t *geminiSchemaTransformer) transform(schemaNode any) {
if schemaNode == nil {
return
}
rv := reflect.ValueOf(schemaNode)
switch rv.Kind() {
case reflect.Map, reflect.Slice, reflect.Pointer:
if rv.IsNil() {
return
}
id := rv.Pointer()
if _, seen := t.visited[id]; seen {
return
}
t.visited[id] = struct{}{}
}
switch node := schemaNode.(type) {
case []any:
for _, item := range node {
t.transform(item)
}
return
case map[string]any:
if ref, ok := node["$ref"].(string); ok && strings.HasPrefix(ref, "#/") {
path := strings.Split(ref[2:], "/")
var cur any = t.root
for _, seg := range path {
seg = strings.ReplaceAll(seg, "~1", "/")
seg = strings.ReplaceAll(seg, "~0", "~")
m, ok := cur.(map[string]any)
if !ok {
cur = nil
break
}
cur = m[seg]
if cur == nil {
break
}
}
if resolved, ok := cur.(map[string]any); ok && resolved != nil {
overlay := make(map[string]any, len(node))
for k, v := range node {
if k != "$ref" {
overlay[k] = v
}
}
var copied map[string]any
if b, err := json.Marshal(resolved); err == nil {
_ = json.Unmarshal(b, &copied)
}
if copied == nil {
copied = make(map[string]any, len(resolved))
for k, v := range resolved {
copied[k] = v
}
}
for k := range node {
delete(node, k)
}
for k, v := range copied {
node[k] = v
}
for k, v := range overlay {
node[k] = v
}
delete(node, "$ref")
}
}
if allOf, ok := node["allOf"].([]any); ok {
for _, item := range allOf {
t.transform(item)
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
if itemProps, ok := itemMap["properties"].(map[string]any); ok {
props, _ := node["properties"].(map[string]any)
if props == nil {
props = map[string]any{}
}
for k, v := range itemProps {
if _, exists := props[k]; !exists {
props[k] = v
}
}
node["properties"] = props
}
itemReq := t.asStringSlice(itemMap["required"])
if len(itemReq) > 0 {
curReq := t.asStringSlice(node["required"])
curReq = append(curReq, itemReq...)
node["required"] = t.dedupeStrings(curReq)
}
}
delete(node, "allOf")
}
if typ, ok := node["type"]; ok {
primary := ""
switch v := typ.(type) {
case string:
primary = v
case []any:
for _, it := range v {
if s, ok := it.(string); ok && strings.ToLower(s) != "null" {
primary = s
break
}
}
case []string:
for _, s := range v {
if strings.ToLower(s) != "null" {
primary = s
break
}
}
}
switch strings.ToLower(primary) {
case "string":
node["type"] = "STRING"
case "number":
node["type"] = "NUMBER"
case "integer":
node["type"] = "INTEGER"
case "boolean":
node["type"] = "BOOLEAN"
case "array":
node["type"] = "ARRAY"
case "object":
node["type"] = "OBJECT"
}
}
if node["type"] == "ARRAY" {
if node["items"] == nil {
node["items"] = map[string]any{}
} else if tuple, ok := node["items"].([]any); ok {
for _, it := range tuple {
t.transform(it)
}
tupleTypes := make([]string, 0, len(tuple))
for _, it := range tuple {
if itMap, ok := it.(map[string]any); ok {
if tt, ok := itMap["type"].(string); ok && tt != "" {
tupleTypes = append(tupleTypes, tt)
} else {
tupleTypes = append(tupleTypes, "any")
}
} else {
tupleTypes = append(tupleTypes, "any")
}
}
hint := fmt.Sprintf("(Tuple: [%s])", strings.Join(tupleTypes, ", "))
if origDesc, _ := node["description"].(string); origDesc == "" {
node["description"] = hint
} else {
node["description"] = strings.TrimSpace(origDesc + " " + hint)
}
firstType := ""
if len(tuple) > 0 {
if itMap, ok := tuple[0].(map[string]any); ok {
firstType, _ = itMap["type"].(string)
}
}
isHomogeneous := firstType != ""
for _, it := range tuple {
itMap, ok := it.(map[string]any)
if !ok {
isHomogeneous = false
break
}
tt, _ := itMap["type"].(string)
if tt != firstType {
isHomogeneous = false
break
}
}
if isHomogeneous {
node["items"] = tuple[0]
} else {
node["items"] = map[string]any{}
}
}
}
if anyOf, ok := node["anyOf"].([]any); ok {
for _, item := range anyOf {
t.transform(item)
}
allConst := true
enumVals := make([]string, 0, len(anyOf))
for _, item := range anyOf {
itemMap, ok := item.(map[string]any)
if !ok {
allConst = false
break
}
c, ok := itemMap["const"]
if !ok {
allConst = false
break
}
if c == nil || c == "" {
continue
}
enumVals = append(enumVals, fmt.Sprint(c))
}
if allConst && len(enumVals) > 0 {
node["type"] = "STRING"
node["enum"] = enumVals
} else if _, hasType := node["type"]; !hasType {
for _, item := range anyOf {
if itemMap, ok := item.(map[string]any); ok {
if itemMap["type"] != nil || itemMap["enum"] != nil {
for k, v := range itemMap {
node[k] = v
}
break
}
}
}
}
delete(node, "anyOf")
}
if def, ok := node["default"]; ok {
if desc, ok := node["description"].(string); ok && desc != "" {
if b, err := json.Marshal(def); err == nil {
node["description"] = desc + " (Default: " + string(b) + ")"
}
}
}
for _, k := range []string{
"title", "$schema", "$ref", "strict",
"exclusiveMaximum", "exclusiveMinimum",
"additionalProperties", "oneOf", "default",
"$defs", "propertyNames", "pattern", "minLength",
"maxLength", "minimum", "maximum", "maxItems", "minItems",
"uniqueItems", "multipleOf",
} {
delete(node, k)
}
if props, ok := node["properties"].(map[string]any); ok {
for _, prop := range props {
t.transform(prop)
}
}
if items := node["items"]; items != nil {
t.transform(items)
}
if req := t.asStringSlice(node["required"]); len(req) > 0 {
node["required"] = t.dedupeStrings(req)
}
}
}
func (t *geminiSchemaTransformer) asStringSlice(v any) []string {
switch s := v.(type) {
case []string:
return append([]string(nil), s...)
case []any:
out := make([]string, 0, len(s))
for _, it := range s {
if str, ok := it.(string); ok {
out = append(out, str)
}
}
return out
default:
return nil
}
}
func (t *geminiSchemaTransformer) dedupeStrings(in []string) []string {
seen := map[string]struct{}{}
out := make([]string, 0, len(in))
for _, s := range in {
if s == "" {
continue
}
if _, ok := seen[s]; ok {
continue
}
seen[s] = struct{}{}
out = append(out, s)
}
return out
}