/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
 */
package magic.tool

import magic.dsl.jsonable
import magic.core.tool.*
import magic.core.model.ChatModel
import magic.core.message.ChatMessage
import magic.config.Config
import magic.log.LogUtils
import magic.jsonable.*
import magic.model.{ModelManager, ModelUtils}
import magic.prompt.Template
import magic.parser.OutputParserUtils
import magic.vdb.{SemanticMap, InMemoryVectorDatabase, SimpleIndexMap}

import std.collection.{ArrayList, HashMap, map, collectArray}
import encoding.json.*

/**
 * Structured answer expected from the chat model when filtering tools.
 *
 * Its shape mirrors the JSON "Output Structure" section of `FILTER_TOOL_PROMPT`;
 * instances are decoded via `Response.fromJsonValue` in `SimpleToolManager.filterToolViaLLM`.
 */
@jsonable
private class Response {
    // Tools chosen by the model; the prompt asks for them ranked by relevance.
    let relevantTools: Array<FoundTool>
}

/**
 * One entry of the chat model's tool-filtering answer.
 */
@jsonable
private class FoundTool {
    // Tool name as returned by the model; it is resolved against the manager's tools by exact name.
    let name: String
    // Model-provided justification for the selection; not consumed by the filtering logic.
    let reason: String
}

// Prompt template for LLM-based tool filtering. The `{number}`, `{question}` and `{tools}`
// placeholders are filled in by `format`. The answer must always be a ```json block,
// because the response parser only accepts that form; "no match" is therefore expressed
// as an empty list rather than free text.
private const FILTER_TOOL_PROMPT = """
Objective: Select the most relevant tool(s) for a given user question from a predefined list.

# Instructions:

Analyze the user's question and the provided list of tools.
For each tool, evaluate its relevance to the question based on its name and description.
Return at most {number} of the most relevant tools, ranked by relevance. If no tool fits, return an empty "relevantTools" list.
Use each tool's name exactly as it appears before the colon in the list of available tools.
Justify your selections with a brief reasoning.

# Input:

User Question: {question}
Available Tools:

{tools}

# Output Structure:

```json
{
  "relevantTools": [
    {
      "name": "[Selected Tool Name]",
      "reason": "[Brief justification]"
    },
    ... 
  ]
}
```

# Example Input/Output

User Question: "How can I extract text from a PDF?"
Available Tools:

PDF Text Extractor: Extracts raw text from PDFs while preserving formatting.
-----
Image OCR: Converts text in images/scans to machine-readable text.
-----
PDF Editor: Edits PDF metadata, merges/splits pages, but does not extract text.

Output:
```json
{
  "relevantTools": [
    {
      "name": "PDF Text Extractor",
      "reason": "Directly matches the task of extracting text from PDFs."
    },
    {
      "name": "Image OCR",
      "reason": "May be useful if the PDF contains scanned images of text."
    }
  ]
}
```
"""

// Semantic map backed by an in-memory vector database, used by the embedding-based tool
// filter. There it is populated with `put(tool.description, tool.name)` and searched with
// the user question. NOTE(review): the `String` type argument presumably is the stored
// value type (tool names) — confirm against SemanticMap's definition.
private type InMemSemanticMap = SemanticMap<InMemoryVectorDatabase, SimpleIndexMap, String>

/**
 * A `ToolManager` that keeps its tools in a hash map keyed by tool name.
 *
 * When constructed with `enableFilter: true`, `filterTool` can narrow the registered
 * tools down to those relevant to a question, either by embedding similarity or by
 * asking a chat model, as selected by `ToolSearchConfig.viaEmbedding`.
 */
public class SimpleToolManager <: ToolManager {
    // Registered tools keyed by `Tool.name`; registering a tool with an existing name replaces it.
    private let tools = HashMap<String, Tool>()

    private let _enableFilter: Bool

    /** Creates an empty manager with filtering disabled. */
    public init() {
        _enableFilter = false
    }

    /**
     * Creates a manager pre-populated with `tools`.
     *
     * @param tools initial tools; a later tool replaces an earlier one with the same name.
     * @param enableFilter whether `filterTool` may be called on this manager.
     */
    public init(tools: Collection<Tool>, enableFilter!: Bool = false) {
        this._enableFilter = enableFilter
        for (t in tools) {
            addTool(t)
        }
    }

    /** Whether `filterTool` may be called on this manager. */
    override public prop enableFilter: Bool {
        get() { return _enableFilter }
    }

    /** Registers `tool`, replacing any registered tool with the same name. */
    override public func addTool(tool: Tool): Unit {
        tools.put(tool.name, tool)
    }

    /** Unregisters whatever tool is registered under `tool.name`; a no-op if none is. */
    override public func delTool(tool: Tool): Unit {
        // `remove` already ignores missing keys, so no `contains` pre-check is needed.
        tools.remove(tool.name)
    }

    /** Registers every tool in `tools`, in order (later duplicates win). */
    override public func addTools(tools: Array<Tool>): Unit {
        for (tool in tools) {
            this.addTool(tool)
        }
    }

    /** Unregisters all tools. */
    override public func clear(): Unit {
        this.tools.clear()
    }

    /** Returns the tool registered under `name`, if any. */
    override public func findTool(name: String): Option<Tool> {
        return tools.get(name)
    }

    /** Returns all registered tools; the order follows the underlying hash map. */
    override public func getTools(): Array<Tool> {
        return tools.values().toArray()
    }

    // Partitions the registered tools into (matching `pred`, not matching `pred`).
    private func splitTools(pred: (Tool) -> Bool): (ArrayList<Tool>, ArrayList<Tool>) {
        let groupA = ArrayList<Tool>()
        let groupB = ArrayList<Tool>()
        for (tool in this.getTools()) {
            if (pred(tool)) {
                groupA.append(tool)
            } else {
                groupB.append(tool)
            }
        }
        return (groupA, groupB)
    }

    /**
     * Returns the registered tools that are relevant to `question`.
     *
     * Tools whose `extra["filterable"]` is `"false"` are always kept and placed first.
     * The remaining tools are filtered via embeddings or a chat model, depending on
     * `config.viaEmbedding`, and each appears at most once in the result.
     *
     * @throws ToolException if filtering is not enabled on this manager, or the model
     *         required by `config` is not specified.
     */
    override public func filterTool(question: String, config: ToolSearchConfig): Array<Tool> {
        if (!this.enableFilter) {
            throw ToolException("Tool manager does not enable filter")
        }
        let (specialTools, normalTools) = this.splitTools({ tool =>
            (tool.extra.get("filterable") ?? "") == "false"
        })
        let filteredTools = if (config.viaEmbedding) {
            this.filterToolViaEmbedding(question, normalTools, config)
        } else {
            this.filterToolViaLLM(question, normalTools, config)
        }
        LogUtils.info("Filtered tools: ${normalTools.size} -> ${filteredTools.size}")
        // Special tools are put first
        specialTools.appendAll(filteredTools)
        return specialTools.toArray()
    }

    // Builds a name -> tool lookup restricted to `candidates`. Filter results are resolved
    // against it (and consumed with `remove`), so a result can only name a tool that was
    // actually offered for filtering, and each tool is returned at most once.
    private func indexByName(candidates: ArrayList<Tool>): HashMap<String, Tool> {
        let index = HashMap<String, Tool>()
        for (tool in candidates) {
            index.put(tool.name, tool)
        }
        return index
    }

    // Ranks `tools` by embedding similarity between the question and each tool description,
    // returning at most `config.number` of them. Throws ToolException if no embedding model is set.
    private func filterToolViaEmbedding(question: String,
                                        tools: ArrayList<Tool>,
                                        config: ToolSearchConfig): Array<Tool> {
        let embeddingModel = match (config.embeddingModel) {
            case Some(m) => m
            case None => throw ToolException("An embedding model is not specified")
        }
        if (tools.isEmpty()) {
            // Nothing to rank: skip building the semantic map and embedding anything.
            return Array<Tool>()
        }
        // TODO: cache the semantic map across calls instead of rebuilding it from scratch.
        let semMap = InMemSemanticMap(
            vectorDB: InMemoryVectorDatabase(),
            indexMap: SimpleIndexMap(),
            embeddingModel: embeddingModel
        )
        // Index each candidate by its description; the stored value is the tool name.
        // NOTE(review): tools sharing an identical description may overwrite each other here.
        for (tool in tools) {
            semMap.put(tool.description, tool.name)
        }

        let candidates = indexByName(tools)
        let filteredTools = ArrayList<Tool>()
        for (name in semMap.search(question, number: config.number)) {
            if (let Some(t) <- candidates.remove(name)) {
                filteredTools.append(t)
            }
        }
        return filteredTools.toArray()
    }

    // Asks the chat model to pick the tools from `tools` relevant to the question, using
    // FILTER_TOOL_PROMPT. Throws ToolException if no chat model is set. A response that
    // cannot be parsed is logged and yields an empty selection.
    private func filterToolViaLLM(question: String,
                                  tools: ArrayList<Tool>,
                                  config: ToolSearchConfig): Array<Tool> {
        let model = match (config.chatModel) {
            case Some(m) => m
            case None => throw ToolException("A chat model is not specified")
        }
        if (tools.isEmpty()) {
            // Nothing to choose from: skip the model call entirely.
            return Array<Tool>()
        }
        let toolPrompts = ArrayList<String>()
        for (tool in tools) {
            toolPrompts.append("${tool.name}: ${tool.description}")
        }
        let messages = [
            ChatMessage.system(
                FILTER_TOOL_PROMPT.format(
                    ("number", config.number),
                    ("question", question),
                    ("tools", String.join(toolPrompts.toArray(), delimiter: "\n-----\n"))
                )
            )
        ]
        let response = ModelUtils.makeChatGet<Response>(
            "Simple Tool Manager", model, messages) { msg: ChatMessage =>
            if (let Some(jsonStr) <- OutputParserUtils.extractLastCode(msg.content, "json")) {
                try {
                    let json = JsonValue.fromStr(jsonStr)
                    return Response.fromJsonValue(json)
                } catch (_: JsonException) {
                    // Malformed JSON: fall through so the failure is logged like a missing block.
                }
            }
            LogUtils.error("Failed to parse the chat model response")
            return None
        }
        let candidates = indexByName(tools)
        let filteredTools = ArrayList<Tool>()
        for (tool in (response?.relevantTools) ?? []) {
            if (let Some(t) <- candidates.remove(tool.name)) {
                filteredTools.append(t)
            } else {
                // Unknown name, a tool not offered for filtering, or a repeated selection.
                LogUtils.error("Invalid tool after being filtered: `${tool.name}`")
            }
        }
        return filteredTools.toArray()
    }
}