import unittest
from unittest.mock import patch, Mock
from django.test import TestCase
from sql.models import Instance
from sql.engines.cassandra import CassandraEngine, split_sql
from sql.engines.models import ResultSet
# Switch to True to run the whole suite, including the integration tests
# that talk to a real Cassandra node at ``integration_test_host``.
integration_test_enabled: bool = False
integration_test_host: str = "localhost"
class CassandraEngineTest(TestCase):
    """Unit tests for ``CassandraEngine``; connections and queries are mocked."""

    def setUp(self) -> None:
        # 9042 is Cassandra's default native-protocol (CQL) port. The previous
        # value 9200 is Elasticsearch's default and was misleading here, even
        # though the connection is always mocked in these tests.
        self.ins = Instance.objects.create(
            instance_name="some_ins",
            type="slave",
            db_type="cassandra",
            host="localhost",
            port=9042,
            user="cassandra",
            password="cassandra",
            db_name="some_db",
        )
        self.engine = CassandraEngine(instance=self.ins)

    def tearDown(self) -> None:
        self.ins.delete()

    @patch("sql.engines.cassandra.Cluster.connect")
    def test_get_connection(self, mock_connect):
        """get_connection() should call Cluster.connect exactly once."""
        _ = self.engine.get_connection()
        mock_connect.assert_called_once()

    @patch("sql.engines.cassandra.CassandraEngine.get_connection")
    def test_query(self, mock_get_connection):
        """query() returns a ResultSet even when the connection is a mock."""
        test_sql = """select 123"""
        self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet)

    def test_query_check(self):
        """Only the first statement survives; the comment and the rest are dropped."""
        test_sql = """select 123; -- this is comment
select 456;"""
        result_sql = "select 123;"
        check_result = self.engine.query_check(sql=test_sql)
        self.assertIsInstance(check_result, dict)
        self.assertEqual(False, check_result.get("bad_query"))
        self.assertEqual(result_sql, check_result.get("filtered_sql"))

    def test_query_check_error(self):
        """A DROP TABLE statement is flagged as a bad query."""
        test_sql = """drop table table_a"""
        check_result = self.engine.query_check(sql=test_sql)
        self.assertIsInstance(check_result, dict)
        self.assertEqual(True, check_result.get("bad_query"))

    @patch("sql.engines.cassandra.CassandraEngine.query")
    def test_get_all_databases(self, mock_query):
        """Single-column row tuples are flattened into a list of names."""
        mock_query.return_value = ResultSet(rows=[("some_db",)])
        result = self.engine.get_all_databases()
        self.assertIsInstance(result, ResultSet)
        self.assertEqual(result.rows, ["some_db"])

    @patch("sql.engines.cassandra.CassandraEngine.query")
    def test_get_all_tables(self, mock_query):
        """Table names are flattened from single-column row tuples."""
        # Sample result of a table-listing query.
        mock_query.return_value = ResultSet(rows=[("u",), ("v",), ("w",)])
        table_list = self.engine.get_all_tables("some_db")
        self.assertEqual(table_list.rows, ["u", "v", "w"])

    @patch("sql.engines.cassandra.CassandraEngine.query")
    def test_describe_table(self, mock_query):
        """describe_table() delegates a DESCRIBE TABLE statement to query()."""
        mock_query.return_value = ResultSet()
        self.engine.describe_table("some_db", "some_table")
        mock_query.assert_called_once_with(
            db_name="some_db", sql="describe table some_table"
        )

    @patch("sql.engines.cassandra.CassandraEngine.query")
    def test_get_all_columns_by_tb(self, mock_query):
        """Column names are flattened and column_list is passed through."""
        mock_query.return_value = ResultSet(
            rows=[("name",)], column_list=["column_name"]
        )
        result = self.engine.get_all_columns_by_tb("some_db", "some_table")
        self.assertEqual(result.rows, ["name"])
        self.assertEqual(result.column_list, ["column_name"])

    def test_split(self):
        """split_sql() emits a leading ``USE <db_name>`` statement."""
        sql = """CREATE TABLE emp(
emp_id int PRIMARY KEY,
emp_name text,
emp_city text,
emp_sal varint,
emp_phone varint
);"""
        sql_result = split_sql(db_name="test_db", sql=sql)
        self.assertEqual(sql_result[0], "USE test_db")

    def test_execute_check(self):
        """execute_check() keeps the full SQL and marks the DDL as audited."""
        sql = """CREATE TABLE emp(
emp_id int PRIMARY KEY,
emp_name text,
emp_city text,
emp_sal varint,
emp_phone varint
);"""
        check_result = self.engine.execute_check(db_name="test_db", sql=sql)
        self.assertEqual(check_result.full_sql, sql)
        # rows[1] is presumably the CREATE TABLE, following the USE statement.
        self.assertEqual(check_result.rows[1].stagestatus, "Audit completed")

    @patch("sql.engines.cassandra.CassandraEngine.get_connection")
    def test_execute(self, mock_connection):
        """Successful execution, then failure of the first statement.

        When the first statement raises, it is reported as failed with the
        exception text, and the following statement is reported as skipped.
        """
        mock_execute = Mock()
        mock_connection.return_value.execute = mock_execute
        sql = """CREATE TABLE emp(
emp_id int PRIMARY KEY,
emp_name text,
emp_city text,
emp_sal varint,
emp_phone varint
);"""
        execute_result = self.engine.execute(db_name="test_db", sql=sql)
        self.assertEqual(execute_result.rows[1].stagestatus, "Execute Successfully")
        mock_execute.assert_called()

        # Failure path: clear recorded calls first, then make every call raise.
        mock_execute.reset_mock(return_value=True)
        mock_execute.side_effect = ValueError("foo")
        execute_result = self.engine.execute(db_name="test_db", sql=sql)
        self.assertEqual(execute_result.rows[0].stagestatus, "Execute Failed")
        self.assertEqual(execute_result.rows[1].stagestatus, "Execute Failed")
        self.assertEqual(execute_result.rows[0].errormessage, "异常信息:foo")
        self.assertEqual(execute_result.rows[1].errormessage, "前序语句失败, 未执行")
        mock_execute.assert_called()

    def test_filter_sql(self):
        """filter_sql() adds a missing LIMIT, keeps a smaller one and caps a larger one."""
        sql_without_limit = "select name from user_info;"
        self.assertEqual(
            self.engine.filter_sql(sql_without_limit, limit_num=100),
            "select name from user_info limit 100;",
        )
        sql_with_normal_limit = "select name from user_info limit 1;"
        self.assertEqual(
            self.engine.filter_sql(sql_with_normal_limit, limit_num=100),
            "select name from user_info limit 1;",
        )
        sql_with_high_limit = "select name from user_info limit 1000;"
        self.assertEqual(
            self.engine.filter_sql(sql_with_high_limit, limit_num=100),
            "select name from user_info limit 100;",
        )
@unittest.skipIf(
    not integration_test_enabled, "cassandra integration test is not enabled"
)
class CassandraIntegrationTest(TestCase):
    """Integration tests against a live Cassandra node at ``integration_test_host``.

    Skipped unless ``integration_test_enabled`` is True.
    """

    def setUp(self):
        self.instance = Instance.objects.create(
            instance_name="int_ins",
            type="slave",
            db_type="cassandra",
            host=integration_test_host,
            port=9042,
            user="cassandra",
            password="cassandra",
            db_name="",
        )
        self.engine = CassandraEngine(instance=self.instance)
        self.keyspace = "test"
        self.table = "test_table"
        # Create the keyspace. IF NOT EXISTS keeps setUp working if an earlier
        # run was interrupted before tearDown could drop it.
        self.engine.execute(
            sql=f"create keyspace if not exists {self.keyspace} with replication = "
            "{'class': 'org.apache.cassandra.locator.SimpleStrategy', "
            "'replication_factor': '1'};"
        )
        # Create the table.
        self.engine.execute(
            db_name=self.keyspace,
            sql=f"""create table if not exists {self.table}( name text primary key );""",
        )

    def tearDown(self):
        # Drop the keyspace created in setUp. Use self.keyspace rather than a
        # hard-coded name so setUp and tearDown cannot drift apart.
        self.engine.execute(sql=f"drop keyspace if exists {self.keyspace};")

    def test_integrate_query(self):
        """A row inserted through execute() is readable through query()."""
        self.engine.execute(
            db_name=self.keyspace,
            sql=f"insert into {self.table} (name) values ('test')",
        )
        result = self.engine.query(
            db_name=self.keyspace, sql=f"select * from {self.table}"
        )
        self.assertEqual(result.rows[0][0], "test")