/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
*/
package magic.rag
import magic.dsl.*
import magic.core.*
import magic.jsonable.*
import magic.agent.*
import magic.agent_executor.AgentExecutorManager
import magic.prompt.*
import magic.model.*
import magic.utils.*
import magic.memory.*
import magic.tool.*
import magic.log.LogUtils
import encoding.json.JsonValue
import std.collection.{ArrayList, HashMap}
private func doSqliteQuery(ppDb: CPointer<CPointer<Unit>>, query: String): String {
let ppStmt = SqliteUtils.sqlPrepare(ppDb, query)
try {
let result = ArrayList<String>()
let columnCount = SqliteUtils.sqlColumnCount(ppStmt)
while (SqliteUtils.sqlStep(ppStmt) != SqliteUtils.SQLITE_DONE) {
for (i in 0..columnCount) {
result.append(SqliteUtils.sqlColumnText(ppStmt, i))
}
}
return result.toString()
} finally {
SqliteUtils.sqlFinalize(ppStmt)
}
}
@agent[
tools: [executeQuery],
memory: false
]
class SqliteAgent {
var schema = ""
SqliteAgent(
let dbPath: String,
let table!: String = ""
) { }
@prompt[pattern: RISE](
role: "你是 SQLite 专家,根据用户问题从数据库中查找内容。",
input: "用户想要查找的问题",
steps: """
从数据库中获取想要查找的表的 schema 信息和对应的样本数据,
然后根据 schema 信息和对应的样本数据,生成正确的 SQLite 查询语句,执行查询语句,并返回查询结果。
如果查询失败,可以再次生成新的查询语句进行尝试;最多尝试次数5次。
如果用户没有指定查询个数,默认查找十个。
禁止直接回答用户问题,禁止修改用户查询结果。
""",
expectation: "生成正确的 SQLite 查询语句,执行查询语句,并直接返回 JSON 格式的查询结果,结果必须要包含 column 名。禁止修改用户查询结果。"
)
@tool[
description: "从数据库中获取表的 schema 信息和样本数据",
]
func getSchemaAndSampleData(): String {
if (this.schema != "") {
return this.schema
}
let sql = "SELECT name, sql FROM sqlite_master WHERE type ='table'"
let ppDb: CPointer<CPointer<Unit>> = SqliteUtils.sqlOpen(this.dbPath)
let ppStmt = SqliteUtils.sqlPrepare(ppDb, sql)
let result: ArrayList<String> = ArrayList<String>()
// If the table is specified, sample data from the table
// Otherwise, fetch all table schema
try {
if (this.table.isEmpty()) {
while (SqliteUtils.sqlStep(ppStmt) != SqliteUtils.SQLITE_DONE) {
result.append(SqliteUtils.sqlColumnText(ppStmt, 1))
}
} else {
let tables = this.table.split(",")
// Iterate all tables
while (SqliteUtils.sqlStep(ppStmt) != SqliteUtils.SQLITE_DONE) {
let currTable = SqliteUtils.sqlColumnText(ppStmt, 0)
if (tables.contains(currTable)) {
result.append(SqliteUtils.sqlColumnText(ppStmt, 1))
// LogUtils.info(currTable)
// Sample the first 5 rows
let data = doSqliteQuery(ppDb, "SELECT * FROM ${currTable} LIMIT 5")
result.append(data)
// LogUtils.info(data)
}
}
}
} finally {
SqliteUtils.sqlFinalize(ppStmt)
SqliteUtils.sqlClose(ppDb)
}
this.schema = result.toString()
// LogUtils.info(schema)
return this.schema
}
@tool[
description: "在数据库中执行对应的sql查询语句,然后返回查询结果",
parameters: {
query: "要执行的sql语句"
}
]
private func executeQuery(query: String): String {
let ppDb: CPointer<CPointer<Unit>> = SqliteUtils.sqlOpen(this.dbPath)
try {
return doSqliteQuery(ppDb, query)
} finally {
SqliteUtils.sqlClose(ppDb)
}
}
}