prune

Function

Configures the parameters used during pruning and returns pruning information, which can then be used to prune the model during evaluation.

Prototype

prune(reserved_ratio=0.75, un_prune_list=None)

Parameters

Parameter: reserved_ratio
Input/Return: Input
Description: Retention ratio of the prunable parameters, that is, the proportion kept after pruning.
Constraints: Optional. Data type: float. Default value: 0.75. Value range: [0, 1].

Parameter: un_prune_list
Input/Return: Input
Description: Layers to exclude from pruning. By default, the first and last layers are not pruned.
Constraints: Optional. Data type: list whose elements must be int or str. Default value: None.
An int element is the index of a layer to exclude from pruning; only the Conv2d and Linear operators that are candidates for pruning are counted when indexing.
A str element is the name of an operator in the network.
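As a reading aid for the constraints above, here is a minimal pure-Python sketch of how the two arguments could be checked and how `un_prune_list` entries could be resolved to layer names. Everything here is hypothetical: the `(name, type)` layer list, `check_args`, and `resolve_excluded` are illustrative names, not msmodelslim APIs. It also assumes that the documented default ("first and last layers") means the first and last prunable operators, and that an explicit list replaces the default rather than adding to it.

```python
from typing import List, Optional, Sequence, Set, Tuple, Union

# Hypothetical model description: (operator name, operator type) in forward order.
# Illustration only; this is not msmodelslim's internal representation.
Layer = Tuple[str, str]
PRUNABLE_TYPES = ("Conv2d", "Linear")


def check_args(reserved_ratio: float,
               un_prune_list: Optional[List[Union[int, str]]] = None) -> None:
    """Raise if the arguments violate the documented constraints."""
    if not isinstance(reserved_ratio, (int, float)):
        raise TypeError("reserved_ratio must be a float")
    if not 0.0 <= reserved_ratio <= 1.0:
        raise ValueError("reserved_ratio must lie in [0, 1]")
    if un_prune_list is not None:
        if not isinstance(un_prune_list, list):
            raise TypeError("un_prune_list must be a list")
        for item in un_prune_list:
            if not isinstance(item, (int, str)):
                raise TypeError("un_prune_list elements must be int or str")


def resolve_excluded(layers: Sequence[Layer],
                     un_prune_list: Optional[List[Union[int, str]]] = None) -> Set[str]:
    """Return the names of the layers kept out of pruning."""
    # Integer indices count only the prunable Conv2d and Linear operators.
    prunable = [name for name, op_type in layers if op_type in PRUNABLE_TYPES]
    if un_prune_list is None:
        # Documented default: the first and last layers are not pruned.
        # Assumption: "first/last" refers to the first/last prunable operators.
        return {prunable[0], prunable[-1]} if prunable else set()
    excluded: Set[str] = set()
    for item in un_prune_list:
        if isinstance(item, int):
            excluded.add(prunable[item])  # index into the prunable layers only
        else:
            excluded.add(item)  # operator name used as-is
    return excluded
```

For a toy network `[("conv1", "Conv2d"), ("relu1", "ReLU"), ("conv2", "Conv2d"), ("fc", "Linear")]`, the index `1` resolves to `conv2`, because `relu1` is not a Conv2d or Linear operator and is therefore not counted.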

Sample

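import torch
import torchvision
from msmodelslim.pytorch.prune.prune_torch import PruneTorch

model = torchvision.models.vgg16(pretrained=False)  # randomly initialized VGG16
model.eval()
# Example input tensor: batch size 1, 3 channels, 224 x 224, float32.
prune_torch = PruneTorch(model, torch.ones([1, 3, 224, 224]).type(torch.float32))
# reserved_ratio=0.5; desc holds the returned pruning information.
desc = prune_torch.prune(0.5)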
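For reference, torchvision's `vgg16` contains 16 Conv2d and Linear layers (13 convolutions in `features`, 3 linear layers in `classifier`). Counting them in module order, index 0 in `un_prune_list` would correspond to `features.0` and index 15 to `classifier.6`, assuming operator names follow PyTorch module names. The sketch below tabulates a hypothetical plan for `prune(0.5)` on this model. `VGG16_PRUNABLE` and `plan_widths` are illustrative names, and the rule that every non-excluded layer keeps `round(width * reserved_ratio)` output channels is an assumption, not documented msmodelslim behavior.

```python
# Output widths of torchvision vgg16's Conv2d/Linear layers, in module order.
VGG16_PRUNABLE = [
    ("features.0", 64), ("features.2", 64),
    ("features.5", 128), ("features.7", 128),
    ("features.10", 256), ("features.12", 256), ("features.14", 256),
    ("features.17", 512), ("features.19", 512), ("features.21", 512),
    ("features.24", 512), ("features.26", 512), ("features.28", 512),
    ("classifier.0", 4096), ("classifier.3", 4096), ("classifier.6", 1000),
]


def plan_widths(layers, reserved_ratio, un_prune_list=None):
    """Hypothetical: map each prunable layer name to the output width kept."""
    names = [name for name, _ in layers]
    if un_prune_list is None:
        # Documented default: the first and last layers are not pruned.
        excluded = {names[0], names[-1]}
    else:
        # int -> index among prunable layers; str -> operator name.
        excluded = {names[i] if isinstance(i, int) else i for i in un_prune_list}
    plan = {}
    for name, width in layers:
        # Assumption: a uniform per-layer ratio applied to output channels.
        plan[name] = width if name in excluded else round(width * reserved_ratio)
    return plan


if __name__ == "__main__":
    plan = plan_widths(VGG16_PRUNABLE, 0.5)
    for name, width in VGG16_PRUNABLE:
        print(f"{name}: {width} -> {plan[name]}")
```

Keeping `classifier.6` unpruned by default matters here: its 1000 outputs are the class scores, so shrinking it would change the model's output shape.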