import torch
from nssmpc import Party2PC, Party3PC
from nssmpc.infra.mpc import PartyCtx
from nssmpc.infra.tensor import RingTensor
from nssmpc.primitives.secret_sharing.arithmetic import SecretSharingScheme, AdditiveSecretSharing, \
ReplicatedSecretSharing
def SecretTensor(*, tensor: torch.Tensor = None,
src_id: int = None) -> AdditiveSecretSharing | ReplicatedSecretSharing | None:
"""Creates an arithmetic secret shared tensor.
Args:
tensor: The input tensor to be secret shared. If `src` is specified, this argument is ignored.
src_id: The source party ID from which to receive the secret shared tensor. If `None`, the tensor is secret shared from the local party.
Returns:
SecretSharingScheme: An arithmetic secret shared tensor.
Examples:
for data owner:
>>> share = SecretTensor(tensor=x)
for other parties:
>>> share = SecretTensor(src_id=0)
"""
party = PartyCtx.get()
if tensor is None and src_id is None:
raise ValueError("Either `tensor` or `src` must be provided.")
if isinstance(party, Party2PC):
if tensor is not None:
share_0, share_1 = AdditiveSecretSharing.share(RingTensor.convert_to_ring(tensor))
party.send(share_1)
return share_0
else:
return party.recv()
elif isinstance(party, Party3PC):
if tensor is not None:
share_0, share_1, share_2 = ReplicatedSecretSharing.share(RingTensor.convert_to_ring(tensor))
party.send((party.party_id + 1) % 3, share_1)
party.send((party.party_id + 2) % 3, share_2)
return share_0
elif src_id is not None:
return party.recv(src_id)
else:
raise RuntimeError("Unsupported party type for SecretTensor.")