from torch.utils.data import DataLoader
NssMPC Library Tutorial - Neural Network Inference with Secret Sharing
Overview
This tutorial demonstrates how to perform privacy-preserving neural network inference using NssMPC. We support both 2-party and 3-party scenarios. The model owner holds a pre-trained neural network model, and data owners hold private input data. Through secret sharing techniques, data owners can obtain inference results without revealing their input data, and the model owner's weights remain confidential.
Prerequisites
Before running neural network inference with NssMPC, make sure you have completed the setup described in the Getting Started Guide and installed the NssMPC library.
You should also have access to the pre-trained model weights and the dataset for inference. In this tutorial, we will use the AlexNet model trained on the CIFAR-10 dataset.
First, choose the security model and the neural network model you want to use. Make sure the Beaver-triple generation script and the party scripts below agree on both (e.g., change SEMI_HONEST to HONEST_MAJORITY, or replace AlexNet with ResNet, in every script). Then run the generation script to produce the Beaver triples needed for neural network inference:
python scripts/neural_network_beaver_generation.py
Then we can proceed with the tutorial.
Important: Separate Execution
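Each party must run as its own Python process (a separate script, typically in its own terminal). This keeps each party's private data inside its own process and makes the parties communicate only over the network, as they would in a real deployment.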
2-Party Neural Network Inference
Party 0 Setup (Model Owner)
# File: party0_2pc.py - MUST be run separately
import torch
import nssmpc.application.neural_network as nn
from nssmpc import PartyRuntime, Party2PC, SEMI_HONEST
from nssmpc.config import NN_path
from data.AlexNet.Alexnet import AlexNet

if __name__ == '__main__':
    # Party 0 is the model owner: it holds the trained AlexNet weights.
    party = Party2PC(0, SEMI_HONEST)
    party.online()
    with PartyRuntime(party):
        # Load the trained plaintext weights, then split them into secret shares.
        plaintext_model = AlexNet()
        weights_file = NN_path + 'AlexNet_CIFAR10.pkl'
        plaintext_model.load_state_dict(torch.load(weights_file))
        shared_param = nn.utils.share_model_param(model=plaintext_model)

        # Build the secure counterpart of AlexNet and load this party's shares into it.
        SecAlexNet = nn.utils.convert_model(AlexNet)
        ciphertext_model = nn.utils.load_shared_param(SecAlexNet(), shared_param)

        # Batches arrive secret-shared from Party 1 (the data owner). Each
        # result is reconstructed only at Party 1 (target_id=1).
        # NOTE(review): the data owner unpacks `inputs, labels = data` from its
        # loader, while here `data` goes straight into the model. Presumably the
        # receiving side yields only the shared inputs; confirm.
        for data in nn.utils.SharedDataLoader(src_id=1):
            ciphertext_model(data).recon(target_id=1)
    party.close()
Party 1 Setup (Data Owner)
# File: party1_2pc.py - MUST be run separately
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
import nssmpc.application.neural_network as nn
from nssmpc import PartyRuntime, SEMI_HONEST, Party2PC
from nssmpc.config import NN_path
from nssmpc.infra.utils.profiling import RuntimeTimer
from data.AlexNet.Alexnet import AlexNet

if __name__ == '__main__':
    # Party 1 is the data owner. In a 2-party setup the party ids are 0
    # (model owner) and 1. The id must match the src_id=1 / target_id=1
    # used by Party 0.
    party = Party2PC(1, SEMI_HONEST)
    party.online()
    with PartyRuntime(party):
        # Private input: the first 1024 CIFAR-10 test images, one per batch.
        transform1 = transforms.Compose([transforms.ToTensor()])
        test_set = torchvision.datasets.CIFAR10(root=NN_path, train=False, download=True, transform=transform1)
        indices = list(range(1024))
        subset_data = Subset(test_set, indices)
        test_loader = torch.utils.data.DataLoader(subset_data, batch_size=1, shuffle=False, num_workers=0)

        # Receive this party's shares of the model weights from Party 0 and
        # load them into the secure AlexNet.
        shared_param = nn.utils.share_model_param(src_id=0)
        SecAlexNet = nn.utils.convert_model(AlexNet)
        ciphertext_model = SecAlexNet()
        ciphertext_model = nn.utils.load_shared_param(ciphertext_model, shared_param)

        # Wrap the plaintext loader so every batch is secret-shared with Party 0.
        shared_data_loader = nn.utils.SharedDataLoader(data_loader=test_loader)
        correct_total = 0
        total_total = 0
        for data in shared_data_loader:
            correct = 0
            total = 0
            inputs, labels = data
            # Secure inference plus reconstruction of the logits at this party
            # (target_id=1), profiled with communication statistics.
            with RuntimeTimer(tag="Inference", enable_comm_stats=True):
                secret_result = ciphertext_model(inputs)
                plaintext_result = secret_result.recon(target_id=1).convert_to_real_field()
            _, predicted = torch.max(plaintext_result, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            total_total += total
            correct_total += correct
            # Per-batch accuracy.
            print('Accuracy of the communication on test images:{}%'.format(100 * correct / total))
        # Accuracy over all evaluated images.
        print('Accuracy of the communication on test images:{}%'.format(100 * correct_total / total_total))
    party.close()
3-Party Neural Network Inference
Party 0 Setup (Model Owner and Data Owner)
# File: party0_3pc.py - MUST be run separately
import torch.utils.data
import nssmpc.application.neural_network as nn
from nssmpc import Party3PC, SEMI_HONEST, PartyRuntime, SecretTensor
from nssmpc.infra.utils.profiling import RuntimeTimer
from data.AlexNet.Alexnet import AlexNet

if __name__ == '__main__':
    # In this example Party 0 owns both the model and the test input.
    party = Party3PC(0, SEMI_HONEST)
    party.online()
    with PartyRuntime(party):
        # Random integer-valued input in [-10, 10), converted to float.
        test_input = torch.randint(-10, 10, [1, 3, 32, 32]) * 1.0
        print("test_input:", test_input)

        # Plaintext reference run. No checkpoint is loaded, so the weights are
        # freshly initialised. That is enough to compare against the secure output.
        plaintext_model = AlexNet()
        test_output = plaintext_model(test_input)
        print("test_output", test_output)

        # Secret-share the parameters with Parties 1 and 2, then build the
        # secure AlexNet and load this party's shares into it.
        shared_param = nn.utils.share_model_param(model=plaintext_model)
        SecAlexNet = nn.utils.convert_model(AlexNet)
        net = nn.utils.load_shared_param(SecAlexNet(), shared_param)

        # Secret-share the input. Parties 1 and 2 receive it via SecretTensor(src_id=0).
        share_input = SecretTensor(tensor=test_input)

        # Secure inference, profiled with communication statistics.
        with RuntimeTimer(enable_comm_stats=True):
            output = net(share_input)

        # Reveal the result to Party 0 only; compare it with test_output above.
        print("output", output.recon(target_id=0).convert_to_real_field())
    party.close()
Party 1 Setup (Computation Party)
# File: party1_3pc.py - MUST be run separately
import nssmpc.application.neural_network as nn
from nssmpc import Party3PC, SEMI_HONEST, PartyRuntime, SecretTensor
from data.AlexNet.Alexnet import AlexNet

if __name__ == '__main__':
    # Party 1 is a computation party: it contributes neither weights nor input.
    party = Party3PC(1, SEMI_HONEST)
    party.online()
    with PartyRuntime(party):
        # Receive this party's shares of the weights from Party 0.
        local_param = nn.utils.share_model_param(src_id=0)

        # Build the secure AlexNet and load the received shares into it.
        SecAlexNet = nn.utils.convert_model(AlexNet)
        ciphertext_model = nn.utils.load_shared_param(SecAlexNet(), local_param)

        # Receive this party's share of the input from Party 0, then run inference.
        share_input = SecretTensor(src_id=0)
        output = ciphertext_model(share_input)

        # Take part in reconstruction; the plaintext result goes to Party 0 only.
        output.recon(target_id=0)
    party.close()
Party 2 Setup (Computation Party)
# File: party2_3pc.py - MUST be run separately
import nssmpc.application.neural_network as nn
from nssmpc import Party3PC, SEMI_HONEST, PartyRuntime, SecretTensor
from data.AlexNet.Alexnet import AlexNet

if __name__ == '__main__':
    # Party 2 is a computation party: it contributes neither weights nor input.
    party = Party3PC(2, SEMI_HONEST)
    party.online()
    with PartyRuntime(party):
        # Receive this party's shares of the weights from Party 0.
        local_param = nn.utils.share_model_param(src_id=0)

        # Build the secure AlexNet and load the received shares into it.
        SecAlexNet = nn.utils.convert_model(AlexNet)
        ciphertext_model = nn.utils.load_shared_param(SecAlexNet(), local_param)

        # Receive this party's share of the input from Party 0, then run inference.
        share_input = SecretTensor(src_id=0)
        output = ciphertext_model(share_input)

        # Take part in reconstruction; the plaintext result goes to Party 0 only.
        output.recon(target_id=0)
    party.close()
Key Components Explained
1. Model Conversion (nn.utils.convert_model)
To convert a standard PyTorch model to a secure model compatible with NssMPC:
import nssmpc.application.neural_network as nn
# convert_model takes the plaintext model *class* (not an instance) and returns
# a secure model class, which must then be instantiated.
SecModelClass = nn.utils.convert_model(PlaintextModelClass) # e.g., AlexNet
ciphertext_model = SecModelClass()
Or build the model manually using NssMPC layers:
from torch.nn import Module
import nssmpc.application.neural_network as nn
from nssmpc.application.neural_network.layers import SecConv2d, SecReLU, SecLinear


class CustomModel(Module):
    """Small secure CNN assembled directly from NssMPC layers.

    Pipeline: 3x3 convolution (3 -> 16 channels, stride 1, padding 1),
    ReLU, flatten, then a linear layer producing 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = SecConv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu = SecReLU()
        # 16 * 32 * 32 assumes 32x32 inputs (e.g., CIFAR-10) and that SecConv2d
        # follows torch.nn.Conv2d semantics (3x3, stride 1, padding 1 keeps size).
        self.fc = SecLinear(16 * 32 * 32, 10)
        # Add more layers as needed

    def forward(self, x):
        activated = self.relu(self.conv1(x))
        # Flatten everything except the batch dimension before the linear layer.
        flattened = activated.view(activated.size(0), -1)
        return self.fc(flattened)
2. Model Sharing (nn.utils.share_model_param)
For the model owner to secret-share the model parameters with the other parties:
import nssmpc.application.neural_network as nn
# Run on the model owner. `plaintext_model` is the plaintext torch model
# instance whose parameters are secret-shared.
shared_param = nn.utils.share_model_param(model=plaintext_model)
For every other party (e.g., the data owner) to receive its shares of the parameters:
import nssmpc.application.neural_network as nn
# Run on every party except the model owner. MODEL_OWNER_ID is a placeholder
# for the model owner's party id (0 in the examples above).
shared_param = nn.utils.share_model_param(src_id=MODEL_OWNER_ID)
3. Data Sharing
Single Input Sharing
For the data owner to share input data:
from nssmpc import SecretTensor
# Run on the data owner: secret-shares the plaintext tensor `test_input`.
share_input = SecretTensor(tensor=test_input)
For the model owner or other parties to receive shared input data:
from nssmpc import SecretTensor
# Run on every party except the data owner. DATA_OWNER_ID is a placeholder for
# the data owner's party id (0 in the 3-party example above).
share_input = SecretTensor(src_id=DATA_OWNER_ID) # e.g., Parties 1 & 2 in the 3-party example
DataLoader Sharing
For sharing a DataLoader:
import nssmpc.application.neural_network as nn
# Run on the data owner: wraps a regular PyTorch DataLoader so that its batches
# are secret-shared with the other parties while iterating.
shared_data_loader = nn.utils.SharedDataLoader(data_loader=dataloader) # dataloader from PyTorch
For receiving shared DataLoader:
import nssmpc.application.neural_network as nn
# Run on every party except the data owner. DATA_OWNER_ID is a placeholder for
# the data owner's party id (1 in the 2-party example above).
shared_data_loader = nn.utils.SharedDataLoader(src_id=DATA_OWNER_ID)
Then iterate over shared_data_loader to get secret-shared batches.
4. Privacy-Preserving Inference
To perform inference on secret-shared data, just call the model as usual:
# `ciphertext_model` is an instance of the converted secure model and
# `share_input` is secret-shared input. The result stays secret-shared until
# a party calls recon() on it.
secret_result = ciphertext_model(share_input)
Key Points to Remember
For 2-Party Setup:
- Separate Execution: Parties in separate processes
- Model Consistency: Both parties use identical model architecture
- Data Privacy: No party other than the data owner ever sees the plaintext input data
For 3-Party Setup:
- Three Separate Processes: Each party in its own script
- Consistent Mode: All parties use same security mode (SEMI_HONEST or HONEST_MAJORITY)
- Model Structure: All parties initialize same model architecture
Security Guarantees:
- Input Privacy: Data owner's input remains private
- Model Privacy: Model owner's weights remain private
- Output Privacy: Only authorized parties learn inference results
Execution Instructions
2-Party Inference:
- Terminal 1: python party0_2pc.py
- Terminal 2: python party1_2pc.py
3-Party Inference:
- Terminal 1: python party0_3pc.py (Model Owner and Data Owner)
- Terminal 2: python party1_3pc.py (Computation Party)
- Terminal 3: python party2_3pc.py (Computation Party)
Customization Guide
Using Different Models
import torch
import nssmpc.application.neural_network as nn
from nssmpc import PartyRuntime, SecretTensor, Party2PC, SEMI_HONEST


class CustomModel(torch.nn.Module):
    """Plaintext two-layer MLP (784 -> 256 -> 10) to be converted for secure inference."""

    def __init__(self):
        super(CustomModel, self).__init__()
        self.layer1 = torch.nn.Linear(784, 256)
        self.layer2 = torch.nn.Linear(256, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.layer2(x)
        return x


if __name__ == '__main__':
    # A party must be created and brought online before entering PartyRuntime.
    # Here: Party 1 of a 2-party semi-honest setup, receiving the model from Party 0.
    party = Party2PC(1, SEMI_HONEST)
    party.online()
    with PartyRuntime(party):
        # Receive this party's shares of the model parameters from Party 0
        shared_param = nn.utils.share_model_param(src_id=0)
        # convert_model returns a secure model *class*; instantiate it before loading weights
        SecCustomModel = nn.utils.convert_model(CustomModel)
        ciphertext_model = SecCustomModel()
        # or use Modules in nssmpc.application.neural_network.layers to build model manually
        # Load shared parameters (keep the returned model, as in the scripts above)
        ciphertext_model = nn.utils.load_shared_param(ciphertext_model, shared_param)
        # Share input data and perform inference as shown earlier
        ...
    party.close()
Error Handling
- Connection Issues: Ensure all parties are online and reachable
- Model Mismatch: Verify the same architecture is used
- Memory Issues: Reduce batch size if out of memory
- Mode Consistency: All parties must use same security mode
Performance Considerations
- 2-Party vs 3-Party: 3-party provides stronger security but more overhead
- Batch Size: Larger batches improve throughput
- Network Latency: Consider network speed between parties
- Model Complexity: Larger models require more computation
Choose 2-party for efficiency when only two entities are involved, or 3-party for stronger security guarantees with honest majority assumption.