package pq
import (
"database/sql"
"fmt"
"log"
"math"
"regexp"
"sync"
"strings"
)
// QueryResult is the outcome of running the search query for a single
// argument set. Workers send one QueryResult per argument set over a channel
// and the collector uses Index to place it in the final result slice.
type QueryResult struct {
// Index is the position, within the args slice, of the argument set that
// produced this result.
Index int
// Rows holds the parsed result rows, one map per row keyed by column name.
// It is left unset when Err is non-nil.
Rows []map[string]interface{}
// Err is non-nil when applying the session settings, executing the query,
// or parsing its rows failed for this argument set.
Err error
}
// Patterns used by validateQuery, compiled once at package initialisation
// rather than on every call.
var (
	// selectPrefixRe matches text that begins with SELECT (case-insensitive),
	// optionally preceded by whitespace and "--" line comments.
	selectPrefixRe = regexp.MustCompile(`(?i)^\s*(?:--.*\s*)*SELECT\s+`)

	// lineCommentRe matches a "--" comment up to the end of its line. The (?m)
	// flag is essential: without it `$` only matches at the end of the whole
	// text, so comments on any line but the last would never be stripped.
	lineCommentRe = regexp.MustCompile(`(?m)--.*$`)

	// vectorOperatorRe matches any of the supported vector distance operators.
	vectorOperatorRe = regexp.MustCompile(`<->|<=>|<#>|<\+>|<~>|<%>`)
)

// validateQuery checks that query is a single SELECT statement that uses a
// vector operator, and terminates the process via log.Fatal otherwise.
//
// The checks, in order:
//   - the query is not empty or whitespace-only;
//   - it starts with SELECT, optionally after leading "--" comment lines;
//   - once "--" line comments are removed, splitting on ';' leaves exactly one
//     non-blank statement;
//   - it contains one of the operators <->, <=>, <#>, <+>, <~> or <%>.
//
// NOTE(review): the checks are purely textual. A "--" or ';' inside a string
// literal is treated as a comment start / statement separator, and the
// operator check also matches an operator that appears only inside a comment.
func validateQuery(query string) {
	trimmedQuery := strings.TrimSpace(query)
	if trimmedQuery == "" {
		log.Fatal("invalid SQL query: must be a non-empty string")
	}
	if !selectPrefixRe.MatchString(trimmedQuery) {
		log.Fatal("invalid SQL query: must be a SELECT statement")
	}

	// Strip comments on every line before counting statements so that a ';'
	// inside a comment is not mistaken for a statement separator.
	noComments := lineCommentRe.ReplaceAllString(trimmedQuery, "")
	validCount := 0
	for _, stmt := range strings.Split(noComments, ";") {
		if strings.TrimSpace(stmt) != "" {
			validCount++
		}
	}
	if validCount != 1 {
		log.Fatalf("invalid SQL query: must contain exactly one query, found %d", validCount)
	}

	if !vectorOperatorRe.MatchString(trimmedQuery) {
		log.Fatal("invalid SQL query: must contain vector operator <->, <=>, <+>, <~>, <%> or <#>")
	}
}
// buildSetSQL renders every entry of scanParams as a "set <name>=<value>"
// statement and returns the statements as a slice.
//
// A nil or empty map yields nil. The statements follow Go's map iteration
// order, which is unspecified, so callers must not rely on their order.
// Names and values are interpolated verbatim with no quoting or escaping.
func buildSetSQL(scanParams map[string]interface{}) []string {
	count := len(scanParams)
	if count == 0 {
		return nil
	}

	statements := make([]string, count)
	next := 0
	for name, value := range scanParams {
		statements[next] = fmt.Sprintf("set %s=%v", name, value)
		next++
	}
	return statements
}
// duplicate opens a new database handle for conninfo using the "opengauss"
// driver and returns whatever sql.Open returns. Each call produces an
// independent *sql.DB that the caller is responsible for closing.
func duplicate(conninfo string) (*sql.DB, error) {
	const driverName = "opengauss"

	handle, err := sql.Open(driverName, conninfo)
	return handle, err
}
// parseRows reads every remaining row from rows and returns them as a slice of
// maps keyed by column name.
//
// Values that scan as []byte are stored as strings; all other values are
// stored exactly as scanned. The result is nil when there are no rows.
//
// If rows.Columns fails, parseRows returns (nil, err). If a Scan fails, it
// returns the rows collected so far together with the scan error. Otherwise
// the error is whatever rows.Err reports after iteration. parseRows does not
// close rows; that is left to the caller.
func parseRows(rows *sql.Rows) ([]map[string]interface{}, error) {
	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// One cell per column, reused for every row: Scan overwrites each cell,
	// and the values are copied into the row's map before the next Scan.
	cells := make([]interface{}, len(columns))
	dest := make([]interface{}, len(columns))
	for i := range cells {
		dest[i] = &cells[i]
	}

	var parsed []map[string]interface{}
	for rows.Next() {
		if scanErr := rows.Scan(dest...); scanErr != nil {
			return parsed, scanErr
		}

		record := make(map[string]interface{})
		for i, name := range columns {
			switch v := cells[i].(type) {
			case []byte:
				record[name] = string(v)
			default:
				record[name] = v
			}
		}
		parsed = append(parsed, record)
	}
	return parsed, rows.Err()
}
// ExecuteMultiSearch runs the same vector-search SELECT once for every
// argument set in args, spreading the argument sets across concurrent workers
// that each use their own database handle.
//
// Parameters:
//   - conninfo: connection string handed to duplicate for every worker.
//   - query: the SELECT statement; it is checked by validateQuery first.
//   - args: one slice of query arguments per execution; must not be empty.
//   - scanParams: session settings applied with "set <name>=<value>" on each
//     worker's handle before it runs any query; may be nil or empty.
//   - threadCount: upper bound on the number of workers; must be > 0. It is
//     capped at len(args), and fewer workers are used when the rounded-up
//     chunk size already covers every argument set.
//
// Returns a slice with one entry per argument set, in the order of args. A
// successful entry holds the parsed rows. A failed entry holds a single map
// with the keys "error" (true), "message" and "index". If a database handle
// cannot be opened or pinged, no query is run and a one-element slice
// containing such an error map with index 0 is returned instead.
//
// An invalid query, empty args or a non-positive threadCount terminates the
// process via log.Fatal.
func ExecuteMultiSearch(
	conninfo string,
	query string,
	args [][]interface{},
	scanParams map[string]interface{},
	threadCount int,
) [][]map[string]interface{} {
	validateQuery(query)
	if len(args) == 0 {
		log.Fatal("args can not be empty")
	}
	if threadCount <= 0 {
		log.Fatal("please confirm that the number of threads is greater than 0.")
	}

	threadCount = int(math.Min(float64(len(args)), float64(threadCount)))
	chunkSize := (len(args) + threadCount - 1) / threadCount
	// Rounding chunkSize up can leave trailing workers with nothing to do
	// (e.g. 5 argument sets on 4 threads need only 3 chunks of 2). Open
	// handles only for chunks that actually contain work so none are leaked.
	workerCount := (len(args) + chunkSize - 1) / chunkSize

	connections := make([]*sql.DB, 0, workerCount)
	for i := 0; i < workerCount; i++ {
		newConn, err := duplicate(conninfo)
		if err == nil {
			// *sql.DB is a connection pool. Limit it to one connection so the
			// session-level SET statements and the queries that follow run on
			// the same underlying connection.
			newConn.SetMaxOpenConns(1)
			err = newConn.Ping()
		}
		if err != nil {
			// Release the handle that just failed as well as the earlier ones.
			if newConn != nil {
				newConn.Close()
			}
			for _, conn := range connections {
				conn.Close()
			}
			return [][]map[string]interface{}{
				{
					{
						"error":   true,
						"message": fmt.Sprintf("failed to create connection: %v", err),
						"index":   0,
					},
				},
			}
		}
		connections = append(connections, newConn)
	}

	// The settings are identical for every worker, so build them once.
	setSQL := buildSetSQL(scanParams)

	// Every argument set yields exactly one QueryResult, so a buffer of
	// len(args) guarantees that workers never block on send.
	resultChan := make(chan QueryResult, len(args))
	var wg sync.WaitGroup
	for i, conn := range connections {
		start := i * chunkSize
		end := start + chunkSize
		if end > len(args) {
			end = len(args)
		}
		wg.Add(1)
		go func(start, end int, conn *sql.DB) {
			defer wg.Done()
			defer conn.Close()

			for _, stmt := range setSQL {
				if _, err := conn.Exec(stmt); err != nil {
					// Without the requested settings none of this worker's
					// queries can be trusted; fail the whole chunk.
					for j := start; j < end; j++ {
						resultChan <- QueryResult{Index: j, Err: fmt.Errorf("failed to execute setSQL: %v (sql: %s)", err, stmt)}
					}
					return
				}
			}

			for j := start; j < end; j++ {
				rows, err := conn.Query(query, args[j]...)
				if err != nil {
					resultChan <- QueryResult{Index: j, Err: fmt.Errorf("query failed: %v", err)}
					continue
				}
				parsedRows, err := parseRows(rows)
				// Close before the next Query: the handle has one connection.
				rows.Close()
				if err != nil {
					resultChan <- QueryResult{Index: j, Err: fmt.Errorf("parse failed: %v", err)}
					continue
				}
				resultChan <- QueryResult{Index: j, Rows: parsedRows}
			}
		}(start, end, conn)
	}

	wg.Wait()
	close(resultChan)

	allResults := make([][]map[string]interface{}, len(args))
	for res := range resultChan {
		if res.Err != nil {
			allResults[res.Index] = []map[string]interface{}{
				{
					"error":   true,
					"message": res.Err.Error(),
					"index":   res.Index,
				},
			}
			continue
		}
		allResults[res.Index] = res.Rows
	}
	return allResults
}