import os
import time
import unittest
import psycopg2
from psycopg2.extras import (init_worker, close_connection, execute_single,
execute_multi_search, init_conn_pool, close_conn_pool)
from . import testutils
from .testutils import ConnectingTestCase
class ConnectInfo:
    """Bundle the connection settings and query defaults used by the tests.

    Connection values are read from the environment when the object is
    created.  Each ``PSYCOPG2_TESTDB*`` variable takes precedence; when it is
    absent the matching ``PG*`` variable is used, and when that is absent too
    the attribute is ``None`` (``dbname`` falls back to ``'psycopg2_test'``).
    """

    def __init__(self):
        env = os.environ

        # Connection settings, resolved once at construction time.
        self.dbname = env.get('PSYCOPG2_TESTDB', 'psycopg2_test')
        self.dbhost = env.get('PSYCOPG2_TESTDB_HOST', env.get('PGHOST'))
        self.dbport = env.get('PSYCOPG2_TESTDB_PORT', env.get('PGPORT'))
        self.dbuser = env.get('PSYCOPG2_TESTDB_USER', env.get('PGUSER'))
        self.dbpass = env.get('PSYCOPG2_TESTDB_PASSWORD', env.get('PGPASSWORD'))

        # Default query: two placeholders, the query vector and the row limit.
        self.sql_template = "SELECT * FROM test_table1 ORDER BY embedding <-> %s LIMIT %s;"

        # Parameters handed to the search helpers alongside the query.
        self.scan_params = {"enable_seqscan": "off", "hnsw_ef_search": 40}

        # The same connection settings, keyed the way the helpers receive them.
        self.dbconfig = dict(
            user=self.dbuser,
            password=self.dbpass,
            host=self.dbhost,
            dbname=self.dbname,
            port=self.dbport,
        )
class ParallelVectorScanTests(ConnectingTestCase):
    """Testing concurrent query capabilities with vector inputs.

    Every test runs against ``test_table1``: ``setUp`` creates it (if needed)
    and inserts two 3-dimensional vectors, ``tearDown`` drops it again.
    """

    def _execute_and_commit(self, *statements):
        """Run *statements* on a fresh connection, commit, and close it."""
        conn = self.connect()
        try:
            cur = conn.cursor()
            for statement in statements:
                cur.execute(statement)
            conn.commit()
        finally:
            # Close explicitly so the fixture connection never outlives the
            # fixture step that opened it.
            conn.close()

    def setUp(self):
        """Create ``test_table1`` and insert the two sample rows."""
        ConnectingTestCase.setUp(self)
        self._execute_and_commit(
            "create table if not exists test_table1 (id integer, embedding vector(3));",
            "insert into test_table1 values(1, '[1,2,3]'),(2, '[2,2,2]');",
        )

    def test_init_worker(self):
        """A worker initialised by hand can run a single search."""
        info = ConnectInfo()
        local_argslist = (('[1,2,2]', 2),)
        init_worker(info.scan_params, info.dbconfig)
        try:
            res = execute_single(local_argslist, info.sql_template)
            self.assertEqual(len(res[0]), 2)
        finally:
            # Release the worker connection even when the assertion fails.
            close_connection()

    def test_invalid_conninfo(self):
        """Searching with a different user name ("test") must raise."""
        info = ConnectInfo()
        dbconfig = {'user': "test", 'password': info.dbpass, 'host': info.dbhost,
                    'dbname': info.dbname, 'port': info.dbport}
        argslist = [('[1,1,1]', 1), ('[2,2,3]', 2)]
        # assertRaises fails the test when nothing is raised; a bare
        # try/except would let that case pass unnoticed.
        with self.assertRaises(Exception):
            execute_multi_search(dbconfig, None, info.sql_template, argslist,
                                 info.scan_params, 2)

    def test_check_connection_pool(self):
        """A search can run through a pool after polling its health check."""
        info = ConnectInfo()
        argslist = [('[1,1,1]', 1), ('[2,2,3]', 2)]
        conn_pool_mgr = init_conn_pool(info.dbconfig, 2, info.scan_params)
        try:
            # NOTE(review): this keeps polling for as long as check_health()
            # is truthy and gives up after 5 seconds. Confirm whether the
            # intended condition was "wait until the pool is healthy".
            start_time = time.time()
            while conn_pool_mgr.check_health():
                if time.time() - start_time > 5:
                    break
            res = execute_multi_search(None, conn_pool_mgr, info.sql_template,
                                       argslist, info.scan_params, 2)
            self.assertEqual(len(res), 2)
        finally:
            # Always release the pool, including when an assertion fails.
            close_conn_pool(conn_pool_mgr)

    def test_connection_pool(self):
        """A search given both a dbconfig and a pool returns one result per query."""
        argslist = [('[1,1,1]', 1), ('[2,2,3]', 2)]
        info = ConnectInfo()
        conn_pool_mgr = init_conn_pool(info.dbconfig, 2, info.scan_params)
        try:
            res = execute_multi_search(info.dbconfig, conn_pool_mgr, info.sql_template,
                                       argslist, info.scan_params, 2)
            self.assertEqual(len(res), 2)
        finally:
            # Always release the pool, including when an assertion fails.
            close_conn_pool(conn_pool_mgr)

    def test_direct_connection(self):
        """A search without a pool returns one result per query."""
        argslist = [('[1,1,1]', 1), ('[2,2,3]', 2)]
        info = ConnectInfo()
        res = execute_multi_search(info.dbconfig, None, info.sql_template, argslist,
                                   info.scan_params, 2)
        self.assertEqual(len(res), 2)

    def test_argslist(self):
        """Malformed argument entries show up as error text in the result."""
        info = ConnectInfo()

        # A bare string, not a tuple: parentheses without a trailing comma do
        # not build a tuple, so the entry is deliberately malformed.
        argslist = ['[1,1,1]']
        res = execute_multi_search(info.dbconfig, None, info.sql_template, argslist,
                                   info.scan_params, 2)
        self.assertIn("not all arguments converted during string formatting", str(res))

        # A bare int is not a sequence of parameters either.
        argslist = [1]
        res = execute_multi_search(info.dbconfig, None, info.sql_template, argslist,
                                   info.scan_params, 2)
        self.assertIn("'int' object does not support indexing", str(res))

    def test_distance(self):
        """Each distance operator can be used in the ORDER BY clause."""
        info = ConnectInfo()
        argslist = [('[1,1,1]', 1), ('[2,2,3]', 2)]
        templates = (
            "SELECT * FROM test_table1 ORDER BY embedding <#> %s LIMIT %s;",
            "SELECT * FROM test_table1 ORDER BY embedding <=> %s LIMIT %s;",
            "SELECT * FROM test_table1 ORDER BY embedding <+> %s LIMIT %s;",
        )
        for sql_template in templates:
            # Pass the operator-specific template, not info.sql_template,
            # so that every operator is really exercised.
            res = execute_multi_search(info.dbconfig, None, sql_template, argslist,
                                       info.scan_params, 2)
            self.assertEqual(len(res), 2, msg=sql_template)

    def test_sql_template(self):
        """Only single-statement templates are accepted."""
        info = ConnectInfo()

        # Two statements in one template must be rejected.
        sql_template = ("SELECT * FROM test_table1 ORDER BY embedding <-> %s LIMIT %s;"
                        "SELECT * FROM test_table1 where id>%s ORDER BY embedding <-> %s LIMIT %s;")
        argslist = [('[1,1,1]', 1, 1, '[2,3,4.5]', 10), ('[2,2,3]', 2, 100, '[5.6,7,8]', 10)]
        # assertRaises fails the test when the template is silently accepted.
        with self.assertRaises(Exception) as caught:
            execute_multi_search(info.dbconfig, None, sql_template, argslist,
                                 info.scan_params, 2)
        self.assertIn("Only one select statement is allowed", str(caught.exception))

        # A single statement with a WHERE clause.
        sql_template = "SELECT * FROM test_table1 where id>%s ORDER BY embedding <-> %s LIMIT %s;"
        argslist = [(1, '[2,3,4.5]', 1), (2, '[5.6,7,8]', 1)]
        res = execute_multi_search(info.dbconfig, None, sql_template, argslist,
                                   info.scan_params, 2)
        self.assertEqual(len(res), 2)

        # A single statement containing a sub-select.
        sql_template = ("SELECT * FROM (SELECT * FROM test_table1 ORDER BY embedding <-> %s LIMIT %s)"
                        "ORDER BY embedding <=> %s LIMIT %s;")
        argslist = [('[4,4,4]', 1, '[2,3,4.5]', 1), ('[2,2,3]', 2, '[5.6,7,8]', 1)]
        res = execute_multi_search(info.dbconfig, None, sql_template, argslist,
                                   info.scan_params, 2)
        self.assertEqual(len(res), 2)

    def tearDown(self):
        """Drop ``test_table1`` and run the base-class cleanup."""
        try:
            # Drop the table on a connection that is closed straight away,
            # and do it before the base-class cleanup so that no connection
            # is opened after that cleanup has already run.
            self._execute_and_commit("drop table if exists test_table1;")
        finally:
            ConnectingTestCase.tearDown(self)
def test_suite():
    """Return a suite holding every test that can be loaded from this module."""
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromName(__name__)
    return suite
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()