prune_model_weight

Function

Model pruning API. It takes the original weights, a Transformer model instance that has already been built with fewer parameters, and a pruning configuration. It then prunes the original weights and loads the pruned weights into that smaller model instance.

Prototype

prune_model_weight(model, config, weight_file_path)

Parameters

Parameter | Input/Return | Description | Constraints
model | Input | The model instance after pruning. | Required. Data type: MindSpore or PyTorch model.
config | Input | The pruning configuration. | Required. Data type: PruneConfig object.
weight_file_path | Input | The path and file name of the original model weight file before pruning. | Required. Data type: string. The weight file for a MindSpore model must be in ckpt format; the weight file for a PyTorch model must be in pt, pth, pkl, or bin format.
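
The file-format constraint on weight_file_path can be checked before calling the API. The following is a minimal sketch, not part of msmodelslim: the helper name check_weight_file_format and the framework labels "mindspore" and "pytorch" are hypothetical, and the check looks only at the file extension, not at the file contents.

```python
from pathlib import Path

# Allowed extensions per framework, taken from the constraint on weight_file_path.
# The framework labels used as keys are hypothetical names for this sketch.
ALLOWED_SUFFIXES = {
    "mindspore": {".ckpt"},
    "pytorch": {".pt", ".pth", ".pkl", ".bin"},
}


def check_weight_file_format(weight_file_path, framework):
    """Return the file suffix if it is allowed for the framework, else raise ValueError.

    Hypothetical helper: it inspects the extension only and does not open the file.
    """
    if framework not in ALLOWED_SUFFIXES:
        raise ValueError(f"unknown framework: {framework!r}")
    suffix = Path(weight_file_path).suffix.lower()
    if suffix not in ALLOWED_SUFFIXES[framework]:
        allowed = ", ".join(sorted(ALLOWED_SUFFIXES[framework]))
        raise ValueError(
            f"{weight_file_path!r} has suffix {suffix!r}; expected one of: {allowed}"
        )
    return suffix
```

Calling check_weight_file_format("xxx.ckpt", "mindspore") passes, while the same path with "pytorch" raises ValueError.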

Sample

from msmodelslim.common.prune.transformer_prune.prune_model import PruneConfig
from msmodelslim.common.prune.transformer_prune.prune_model import prune_model_weight
# Define the pruning configuration.
prune_config = PruneConfig()
# The pattern is a raw string so that the regex escapes (\. and \d) are passed through unchanged.
prune_config.set_steps(['prune_blocks', 'prune_bert_intra_block']). \
  add_blocks_params(r'uniter\.encoder\.encoder\.blocks\.(\d+)\.', {0: 1, 1: 3, 2: 5, 3: 7, 4: 9, 5: 11})
# Pass in the parameters to prune the weights.
# model is the smaller model instance, which must be created before this call.
prune_model_weight(model, prune_config, weight_file_path="xxx.ckpt")
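
To see what the pattern passed to add_blocks_params matches, the sketch below applies the same regular expression to a few made-up weight names using only the Python standard library. It does not call msmodelslim. The weight names, the helper names, and above all the reading of the dictionary (keys as block indices in the pruned model, values as block indices in the original model) are assumptions made for illustration; this page does not state the dictionary's semantics.

```python
import re

# Same pattern as in the sample; the group (\d+) captures the block index.
BLOCK_PATTERN = r'uniter\.encoder\.encoder\.blocks\.(\d+)\.'

# Same dictionary as in the sample.
# ASSUMPTION: key = block index in the pruned model, value = block index in the original model.
LAYER_ID_MAP = {0: 1, 1: 3, 2: 5, 3: 7, 4: 9, 5: 11}


def block_index(weight_name):
    """Return the block index captured by the pattern, or None if the name does not match."""
    match = re.match(BLOCK_PATTERN, weight_name)
    return int(match.group(1)) if match else None


def remap_weight_name(weight_name, layer_id_map):
    """Hypothetical helper: rename an original weight to its assumed pruned-model name.

    Returns None when the weight belongs to a block that is not kept, and returns
    the name unchanged when it does not belong to any block.
    """
    match = re.match(BLOCK_PATTERN, weight_name)
    if match is None:
        return weight_name
    old_to_new = {old: new for new, old in layer_id_map.items()}
    old_index = int(match.group(1))
    if old_index not in old_to_new:
        return None
    start, end = match.span(1)
    return weight_name[:start] + str(old_to_new[old_index]) + weight_name[end:]


if __name__ == "__main__":
    # Made-up weight names for illustration only.
    for name in [
        "uniter.encoder.encoder.blocks.3.attention.weight",
        "uniter.encoder.encoder.blocks.0.attention.weight",
        "uniter.embeddings.weight",
    ]:
        print(name, "->", remap_weight_name(name, LAYER_ID_MAP))
```

Under the stated assumption, the sample configuration would keep the six odd-numbered blocks of the original encoder and renumber them 0 to 5. Check the PruneConfig reference for the actual meaning of the dictionary before relying on this reading.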