import pytest
from dbmind.common.rpc import RPCFunctionRegister, start_rpc_service, stop_rpc_service, ping_rpc_url
from dbmind.common.rpc import RPCJSONAble
from dbmind.common.rpc import RPCRequest, RPCResponse, RPCClient
from dbmind.common.rpc.errors import RPCExecutionError, RPCConnectionError
from dbmind.common.rpc.server import RPCExecutionThread
from dbmind.common.types import Sequence
def test_rpc_jsonable():
seq = Sequence(range(1, 10), range(1, 10), 'test_metric')
subseq = seq[1, 5]
seq_json = RPCJSONAble.serialize(seq)
seq_new = RPCJSONAble.deserialize(seq_json)
assert seq == seq_new
assert Sequence.from_json(seq.json()) == seq
subseq_json = RPCJSONAble.serialize(subseq)
subseq_new = RPCJSONAble.deserialize(subseq_json)
assert subseq == subseq_new
assert Sequence.from_json(subseq.json()) == subseq
dictionary = {'int': 1, 'float': 2.21, 'class': {'seq': seq},
'bool': True,
'list': [1, 2, '3']}
d = RPCJSONAble.serialize(dictionary)
assert RPCJSONAble.deserialize(d) == dictionary
def test_rpc_request_and_response():
request = RPCRequest('username', 'password', 'sum', fargs=[1, 2])
response = RPCResponse(request, True, 3)
assert request.__dict__ == RPCRequest.from_json(request.json()).__dict__
res_exp = response.__dict__
res_rst = RPCResponse.from_json(response.json()).__dict__
assert res_exp.pop('request').__dict__ == res_rst.pop('request').__dict__
assert res_exp == res_rst
@pytest.fixture(scope='session')
def rpc_client_testing():
mapper = RPCFunctionRegister()
@mapper.register()
def func(a, b):
return a + b
mapper['sum'] = sum
assert mapper['sum'] == sum
assert mapper['func'](1, 2) == 3
assert mapper['sum']([1, 2]) == 3
try:
mapper['sum'] = max
except TypeError as e:
assert e
else:
raise AssertionError()
test_user = 'username'
test_pwd = 'password'
s = start_rpc_service(mapper, host='127.0.0.1', port=5454, uri='/rpc',
username=test_user, pwd=test_pwd)
client = RPCClient('http://127.0.0.1:5454/rpc', test_user, test_pwd)
assert client.call('sum', [1, 2, 3]) == 6
assert client.call('func', a=100, b=1.1) == 101.1
try:
client.call('sum', 1, 1)
except RPCExecutionError as e:
assert 'TypeError' in str(e)
else:
raise AssertionError()
threads = []
for i in range(100):
t = RPCExecutionThread(target=client.call, args=('func', 0, i))
t.setName(str(i))
threads.append(t)
threads.append(
RPCExecutionThread(
target=client.call_with_another_credential,
args=('faked', 'faked', 'func', 0, i)
)
)
for t in threads:
t.start()
for i in range(10):
assert not client.handshake('faked', 'faked')
results = set()
exceptions = []
for t in threads:
t.join()
if t.result is not None:
results.add(t.result)
if t.exception:
exceptions.append(t.exception)
assert len(results) == 100
assert len(exceptions) == 100
try:
client.call('nothisfunc', 1, 2)
except RPCExecutionError as e:
assert 'Not found' in str(e)
else:
raise AssertionError()
try:
client2 = RPCClient('http://127.0.0.1:5454/rpc', 'badusername', 'wrongpwd')
assert not client2.handshake()
client2.call('func', 1, 2)
except RPCExecutionError as e:
assert 'authorization' in str(e)
else:
raise AssertionError()
assert ping_rpc_url('http://127.0.0.1:5454/rpc')
assert ping_rpc_url('http://127.0.0.1:5454')
stop_rpc_service(s)
assert not ping_rpc_url('http://127.0.0.1:5454/rpc')
client = RPCClient('http://127.0.0.1:1234/rpc', test_user, test_pwd)
try:
client.set_timeout(0.1)
client.call('func', 1, 2)
except RPCConnectionError:
pass
else:
raise AssertionError()