package sqlfs2
import (
"crypto/tls"
"database/sql"
"fmt"
"regexp"
"strings"
"github.com/c4pt0r/agfs/agfs-server/pkg/plugin/config"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
log "github.com/sirupsen/logrus"
)
type TiDBBackend struct{}
func (b *TiDBBackend) Name() string {
return "tidb"
}
func (b *TiDBBackend) Initialize(cfg map[string]interface{}) (*sql.DB, error) {
dsnStr := config.GetStringConfig(cfg, "dsn", "")
dsnHasTLS := strings.Contains(dsnStr, "tls=")
tlsConfigName := "tidb-sqlfs2"
if dsnHasTLS {
re := regexp.MustCompile(`tls=([^&]+)`)
if matches := re.FindStringSubmatch(dsnStr); len(matches) > 1 {
tlsConfigName = matches[1]
}
}
enableTLS := config.GetBoolConfig(cfg, "enable_tls", false) || dsnHasTLS
if enableTLS {
serverName := config.GetStringConfig(cfg, "tls_server_name", "")
if serverName == "" {
if dsnStr != "" {
re := regexp.MustCompile(`@tcp\(([^:]+):\d+\)`)
if matches := re.FindStringSubmatch(dsnStr); len(matches) > 1 {
serverName = matches[1]
}
} else {
serverName = config.GetStringConfig(cfg, "host", "")
}
}
skipVerify := config.GetBoolConfig(cfg, "tls_skip_verify", false)
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
if serverName != "" {
tlsConfig.ServerName = serverName
}
if skipVerify {
tlsConfig.InsecureSkipVerify = true
log.Warn("[sqlfs2] TLS certificate verification is disabled (insecure)")
}
if err := mysql.RegisterTLSConfig(tlsConfigName, tlsConfig); err != nil {
log.Warnf("[sqlfs2] Failed to register TLS config (may already exist): %v", err)
}
}
var dsn string
if dsnStr != "" {
dsn = dsnStr
} else {
user := config.GetStringConfig(cfg, "user", "root")
password := config.GetStringConfig(cfg, "password", "")
host := config.GetStringConfig(cfg, "host", "127.0.0.1")
port := config.GetStringConfig(cfg, "port", "4000")
database := config.GetStringConfig(cfg, "database", "test")
if password != "" {
dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True",
user, password, host, port, database)
} else {
dsn = fmt.Sprintf("%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True",
user, host, port, database)
}
if enableTLS {
dsn += fmt.Sprintf("&tls=%s", tlsConfigName)
}
}
log.Infof("[sqlfs2] Connecting to TiDB (TLS: %v)", enableTLS)
dbName := extractDatabaseName(dsn, config.GetStringConfig(cfg, "database", ""))
if dbName != "" {
dsnWithoutDB := removeDatabaseFromDSN(dsn)
if dsnWithoutDB != dsn {
tempDB, err := sql.Open("mysql", dsnWithoutDB)
if err == nil {
defer tempDB.Close()
_, err = tempDB.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName))
if err != nil {
log.Warnf("[sqlfs2] Failed to create database '%s': %v", dbName, err)
}
}
}
}
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, fmt.Errorf("failed to open TiDB database: %w", err)
}
db.SetMaxOpenConns(100)
db.SetMaxIdleConns(10)
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping TiDB database: %w", err)
}
return db, nil
}
func (b *TiDBBackend) GetTableSchema(db *sql.DB, dbName, tableName string) (string, error) {
if dbName != "" {
if err := b.SwitchDatabase(db, dbName); err != nil {
return "", err
}
}
var tblName, createTableStmt string
query := fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)
err := db.QueryRow(query).Scan(&tblName, &createTableStmt)
if err != nil {
return "", fmt.Errorf("failed to get table schema: %w", err)
}
return createTableStmt, nil
}
func (b *TiDBBackend) ListDatabases(db *sql.DB) ([]string, error) {
rows, err := db.Query("SHOW DATABASES")
if err != nil {
return nil, fmt.Errorf("failed to list databases: %w", err)
}
defer rows.Close()
var databases []string
for rows.Next() {
var dbName string
if err := rows.Scan(&dbName); err != nil {
return nil, err
}
databases = append(databases, dbName)
}
return databases, nil
}
func (b *TiDBBackend) ListTables(db *sql.DB, dbName string) ([]string, error) {
if err := b.SwitchDatabase(db, dbName); err != nil {
return nil, err
}
rows, err := db.Query("SHOW TABLES")
if err != nil {
return nil, fmt.Errorf("failed to list tables: %w", err)
}
defer rows.Close()
var tables []string
for rows.Next() {
var tableName string
if err := rows.Scan(&tableName); err != nil {
return nil, err
}
tables = append(tables, tableName)
}
return tables, nil
}
func (b *TiDBBackend) SwitchDatabase(db *sql.DB, dbName string) error {
if dbName == "" {
return nil
}
_, err := db.Exec(fmt.Sprintf("USE `%s`", dbName))
if err != nil {
return fmt.Errorf("failed to switch to database %s: %w", dbName, err)
}
return nil
}
func (b *TiDBBackend) GetTableColumns(db *sql.DB, dbName, tableName string) ([]ColumnInfo, error) {
if dbName != "" {
if err := b.SwitchDatabase(db, dbName); err != nil {
return nil, err
}
}
query := fmt.Sprintf("SHOW COLUMNS FROM `%s`", tableName)
rows, err := db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to get table columns: %w", err)
}
defer rows.Close()
var columns []ColumnInfo
for rows.Next() {
var field, colType string
var null, key, extra interface{}
var dflt interface{}
if err := rows.Scan(&field, &colType, &null, &key, &dflt, &extra); err != nil {
return nil, err
}
columns = append(columns, ColumnInfo{Name: field, Type: colType})
}
return columns, nil
}
func extractDatabaseName(dsn string, configDB string) string {
if dsn != "" {
re := regexp.MustCompile(`\)/([^?]+)`)
if matches := re.FindStringSubmatch(dsn); len(matches) > 1 {
return matches[1]
}
}
return configDB
}
func removeDatabaseFromDSN(dsn string) string {
re := regexp.MustCompile(`\)/[^?]+(\?|$)`)
return re.ReplaceAllString(dsn, ")/$1")
}