import pytest
import torch
import mindspore.nn as nn
from mindspore.nn.cell import Cell
from ascend_utils.mindspore.knowledge_distill.distill_losses_func_ms import DISTILL_LOSS_FUNC_MS
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig
class TestPruneKnowledgeDistillConfig(object):
    """Tests for building and validating a ``KnowledgeDistillConfig``."""

    def test_distill_config_given_valid_when_any_then_pass(self):
        """A well-formed config is assembled through the chained builder calls without raising."""
        teacher_block = "uniter.encoder.encoder.blocks.11.output"
        student_block = "uniter.encoder.encoder.blocks.5.output"

        # Intermediate-layer soft label pairing a teacher block with a student block.
        inter_label = {
            "t_module": teacher_block,
            "s_module": student_block,
            "t_output_idx": 0,
            "s_output_idx": 0,
            "loss_func": [
                {"func_name": "KDCrossEntropy", "func_weight": 1},
            ],
            "shape": [2048],
        }
        # Output soft label; its loss entry additionally carries temperature and func_param.
        output_label = {
            "t_output_idx": 0,
            "s_output_idx": 0,
            "loss_func": [
                {
                    "func_name": "KDCrossEntropy",
                    "func_weight": 1,
                    "temperature": 1,
                    "func_param": [],
                },
            ],
        }

        config = KnowledgeDistillConfig()
        # Each call is made on the return value of the previous one, as in a fluent builder.
        (
            config.set_hard_label(0.5, 0)
            .add_inter_soft_label(inter_label)
            .add_output_soft_label(output_label)
            .set_teacher_train()
            .add_custom_loss_func("test_loss_function", nn.Cell())
        )

    def test_distill_config_given_invalid_when_any_then_error(self):
        """Each malformed builder input is rejected with the expected exception type."""
        teacher_block = "uniter.encoder.encoder.blocks.11.output"
        student_block = "uniter.encoder.encoder.blocks.5.output"
        config = KnowledgeDistillConfig()

        # A string is passed where the valid test passes a float.
        with pytest.raises(TypeError):
            config.set_hard_label("1", 1)

        # Intermediate label containing only the teacher module key.
        with pytest.raises(ValueError):
            config.add_inter_soft_label({"t_module": teacher_block})

        # Intermediate label whose single loss-function entry is an empty dict.
        empty_loss_label = {
            "t_module": teacher_block,
            "s_module": student_block,
            "t_output_idx": 0,
            "s_output_idx": 0,
            "loss_func": [{}],
            "shape": [2048],
        }
        with pytest.raises(ValueError):
            config.add_inter_soft_label(empty_loss_label)

        # Intermediate label whose func_param is an int, not a list.
        int_param_label = {
            "t_module": teacher_block,
            "s_module": student_block,
            "t_output_idx": 0,
            "s_output_idx": 0,
            "loss_func": [
                {"func_name": "KDCrossEntropy", "func_weight": 1, "func_param": 1},
            ],
            "shape": [2048],
        }
        with pytest.raises(TypeError):
            config.add_inter_soft_label(int_param_label)

        # Output label whose loss entry omits func_weight (otherwise matches the valid case).
        no_weight_label = {
            "t_output_idx": 0,
            "s_output_idx": 0,
            "loss_func": [
                {"func_name": "KDCrossEntropy", "temperature": 1, "func_param": []},
            ],
        }
        with pytest.raises(ValueError):
            config.add_output_soft_label(no_weight_label)

        # Custom loss given as a string rather than a Cell instance.
        with pytest.raises(TypeError):
            config.add_custom_loss_func("test_name", "nn.Cell()")

        # NOTE(review): original indentation was lost; this call is assumed to sit at
        # method level (outside any pytest.raises), so it is expected not to raise — confirm.
        KnowledgeDistillConfig.check_config(
            config,
            config.custom_loss_func,
            DISTILL_LOSS_FUNC_MS,
            Cell,
        )