"""
Tests for the function_secret_sharing primitives (DCF, DPF, prefix parity query, DICF, GrottoDICF, SigmaDICF).
"""
import unittest
from nssmpc.infra.tensor import RingTensor
from nssmpc.primitives.secret_sharing.function import prefix_parity_query, DCF, DPF, SigmaDICF, GrottoDICF, DICF
# Shared fixtures used by every test in this module.
num_of_keys = 10  # keys generated per gen() call; equals the number of inputs in x
alpha = RingTensor(5)  # `alpha` argument passed to DCF.gen / DPF.gen
beta = RingTensor(1)  # `beta` argument passed to DCF.gen / DPF.gen / GrottoDICF.gen
down_bound = RingTensor(3)  # lower interval bound for the DICF variants
upper_bound = RingTensor(7)  # upper interval bound for the DICF variants
x = RingTensor(list(range(1, 11)))  # plaintext evaluation inputs 1, 2, ..., 10
class Test(unittest.TestCase):
    """End-to-end checks for the function secret sharing primitives.

    Each test generates a key pair, evaluates both parties' shares on the
    module-level inputs ``x`` (1..10), reconstructs the result (by addition or
    XOR, depending on the primitive), and compares it with the expected
    plaintext output.
    """

    def _assert_reconstructs_to(self, res, expected):
        """Assert that the reconstructed result ``res`` equals ``expected`` element-wise.

        Uses ``assertTrue`` instead of a bare ``assert`` so the check is not
        stripped when Python runs with ``-O``, and so a failure reports the
        actual values instead of requiring a debug ``print``.

        :param res: reconstructed RingTensor.
        :param expected: list of expected plaintext values, one per element of ``x``.
        """
        target = RingTensor(expected, device=res.device)
        self.assertTrue((res == target).all(), msg=f"got {res}, expected {target}")

    def test_dcf(self):
        """DCF: additive shares reconstruct to ``beta`` (1) where x < alpha (5), else 0."""
        key0, key1 = DCF.gen(num_of_keys=num_of_keys, alpha=alpha, beta=beta)
        x_col = x.view(-1, 1)  # evaluate on a (10, 1) column view of the inputs
        res_0 = DCF.eval(x=x_col, keys=key0, party_id=0)
        res_1 = DCF.eval(x=x_col, keys=key1, party_id=1)
        res = (res_0 + res_1).view(x.shape)
        self._assert_reconstructs_to(res, [1, 1, 1, 1, 0, 0, 0, 0, 0, 0])

    def test_dpf(self):
        """DPF: additive shares reconstruct to ``beta`` (1) only at x == alpha (5)."""
        key0, key1 = DPF.gen(num_of_keys=num_of_keys, alpha=alpha, beta=beta)
        x_col = x.view(-1, 1)  # evaluate on a (10, 1) column view of the inputs
        res_0 = DPF.eval(x=x_col, keys=key0, party_id=0)
        res_1 = DPF.eval(x=x_col, keys=key1, party_id=1)
        res = (res_0 + res_1).view(x.shape)
        self._assert_reconstructs_to(res, [0, 0, 0, 0, 1, 0, 0, 0, 0, 0])

    def test_ppq(self):
        """Prefix parity query on DPF keys: XOR-reconstructs to 1 where x > alpha (5)."""
        key0, key1 = DPF.gen(num_of_keys=num_of_keys, alpha=alpha, beta=beta)
        x_col = x.view(-1, 1)  # evaluate on a (10, 1) column view of the inputs
        res_0 = prefix_parity_query(x_col, key0, party_id=0)
        res_1 = prefix_parity_query(x_col, key1, party_id=1)
        res = (res_0 ^ res_1).view(x.shape)
        self._assert_reconstructs_to(res, [0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    def test_dicf(self):
        """DICF: additive shares reconstruct to 1 where down_bound (3) <= x <= upper_bound (7)."""
        key0, key1 = DICF.gen(num_of_keys=num_of_keys, down_bound=down_bound, upper_bound=upper_bound)
        # eval only receives x masked by both keys' r_in values
        # (presumably additive shares of the input mask -- confirm in DICF.gen).
        x_shift = x + key0.r_in.reshape(x.shape) + key1.r_in.reshape(x.shape)
        res_0 = DICF.eval(x_shift=x_shift, keys=key0, party_id=0, down_bound=down_bound, upper_bound=upper_bound)
        res_1 = DICF.eval(x_shift=x_shift, keys=key1, party_id=1, down_bound=down_bound, upper_bound=upper_bound)
        res = res_0 + res_1
        self._assert_reconstructs_to(res, [0, 0, 1, 1, 1, 1, 1, 0, 0, 0])

    def test_grotto(self):
        """GrottoDICF: XOR-reconstructs to 1 for x in 3..6 with bounds (3, 7).

        NOTE(review): unlike test_dicf, x == upper_bound (7) is expected to be 0
        here; confirm whether Grotto's interval is half-open or whether the
        negated masking (r_in - x) shifts the boundary.
        """
        key0, key1 = GrottoDICF.gen(num_of_keys=num_of_keys, beta=beta)
        # Grotto is fed r_in - x rather than x + r_in.
        x_shift = key0.r_in.reshape(x.shape) + key1.r_in.reshape(x.shape) - x
        res_0 = GrottoDICF.eval(x_shift=x_shift, key=key0, party_id=0, down_bound=down_bound,
                                upper_bound=upper_bound)
        res_1 = GrottoDICF.eval(x_shift=x_shift, key=key1, party_id=1, down_bound=down_bound,
                                upper_bound=upper_bound)
        res = res_0 ^ res_1
        self._assert_reconstructs_to(res, [0, 0, 1, 1, 1, 1, 0, 0, 0, 0])

    def test_sigma(self):
        """SigmaDICF on masked (x - 5): XOR-reconstructs to 1 exactly where x >= 5."""
        key0, key1 = SigmaDICF.gen(num_of_keys=num_of_keys)
        # The constant 5 is the comparison threshold; it matches alpha's value.
        x_shift = key0.r_in.reshape(x.shape) + key1.r_in.reshape(x.shape) + x - 5
        res_0 = SigmaDICF.eval(x_shift=x_shift, key=key0, party_id=0)
        res_1 = SigmaDICF.eval(x_shift=x_shift, key=key1, party_id=1)
        res = res_0 ^ res_1
        self._assert_reconstructs_to(res, [0, 0, 0, 0, 1, 1, 1, 1, 1, 1])