"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import torch
import torch.nn as nn
def conv_bn_relu(in_channel, out_channel, kernel_size, stride):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack.

    Args:
        in_channel: Number of input channels of the convolution.
        out_channel: Number of output channels of the convolution; also the
            feature count given to the batch-norm layer.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    # Layers are created in execution order so parameter initialisation and
    # the sequential indices (0, 1, 2) stay the same.
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size, stride),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
class TestAscendQuantModel(nn.Module):
    """Tiny conv model: one stem convolution feeding two parallel convolutions.

    The outputs of the two branches are concatenated with ``torch.cat`` using
    its default dimension (0).
    """

    def __init__(self):
        super().__init__()
        # All three layers are single-channel 3x3 convolutions.
        self.first_conv = nn.Conv2d(1, 1, 3)
        self.left_conv = nn.Conv2d(1, 1, 3)
        self.right_conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Run the stem, then both branches, and concatenate the branch outputs."""
        stem = self.first_conv(x)
        left = self.left_conv(stem)
        right = self.right_conv(stem)
        return torch.cat([left, right], dim=0)
class TestNet(nn.Module):
    """
    Minimal classifier: one conv/bn/relu block (attribute ``network``),
    adaptive average pooling to 1x1, and a linear head.

    Args:
        class_num: Number of output features of the final linear layer.
    """

    def __init__(self, class_num=10):
        super().__init__()
        self.network = conv_bn_relu(3, 32, 3, 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, class_num)

    def forward(self, x):
        """Return the linear-head output for input ``x``."""
        pooled = self.avg_pool(self.network(x))
        # Flatten everything after dim 0 before the linear head.
        flat = torch.flatten(pooled, 1)
        return self.fc(flat)
class TestNet2(nn.Module):
    """
    Minimal classifier whose conv/bn/relu block is stored under the attribute
    name ``backbone``, followed by adaptive average pooling to 1x1 and a
    linear head.

    Args:
        class_num: Number of output features of the final linear layer.
    """

    def __init__(self, class_num=10):
        super().__init__()
        self.backbone = conv_bn_relu(3, 32, 3, 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, class_num)

    def forward(self, x):
        """Return the linear-head output for input ``x``."""
        pooled = self.avg_pool(self.backbone(x))
        # Flatten everything after dim 0 before the linear head.
        flat = torch.flatten(pooled, 1)
        return self.fc(flat)
class TestNet3(nn.Module):
    """
    Minimal classifier that returns two values: the linear-head output and
    the untouched input tensor.

    Args:
        class_num: Number of output features of the final linear layer.
    """

    def __init__(self, class_num=10):
        super().__init__()
        self.network = conv_bn_relu(3, 32, 3, 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, class_num)

    def forward(self, x):
        """Return ``(head_output, x)`` where the second item is the original input."""
        original_input = x
        pooled = self.avg_pool(self.network(x))
        head_output = self.fc(torch.flatten(pooled, 1))
        return head_output, original_input
class TestOnnxQuantModel(nn.Module):
    """Four stacked conv/bn/relu blocks, 1x1 adaptive average pooling and a linear head.

    Args:
        class_num: Number of output features of the final linear layer.
    """

    def __init__(self, class_num=10):
        super().__init__()
        # The first block takes 3 input channels; the rest take 32. Every
        # block outputs 32 channels with kernel size 3 and stride 1.
        self.conv_list = nn.ModuleList(
            [conv_bn_relu(channels_in, 32, 3, 1) for channels_in in (3, 32, 32, 32)]
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(32, class_num)

    def forward(self, input_x):
        """Apply the conv blocks in order, pool, flatten from dim 1 and classify."""
        features = input_x
        for block in self.conv_list:
            features = block(features)
        pooled = self.avg_pool(features)
        return self.linear(torch.flatten(pooled, 1))
def get_model():
    """Return a freshly constructed ``TestNet`` with ``class_num=10``."""
    model = TestNet(class_num=10)
    return model
class LrdSampleNetwork(nn.Module):
    """Sample network mixing embedding, linear, convolution and pooling layers.

    Pipeline: embedding stack -> permute to channel-first -> conv stack with a
    residual add -> 5x5 adaptive average pooling -> flatten -> linear layers
    ending in 10 output features.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Embedding(16, 32),
            nn.Linear(32, 64),
            nn.ReLU(inplace=True),
        )
        self.feature = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 1),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d((5, 5))
        self.inner = nn.Linear(64 * 5 * 5, 512)
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.Linear(256, 10),
        )

    def forward(self, inputs):
        """Run the full pipeline on ``inputs``.

        The permutation ``(0, 3, 1, 2)`` implies the embedding output is 4-D,
        i.e. ``inputs`` is presumably a 3-D tensor of embedding indices.
        """
        # Move the embedding feature axis (last) to position 1 for the convs.
        residual = self.embedding(inputs).permute(0, 3, 1, 2)
        features = self.feature(residual) + residual
        pooled = self.pool(features)
        hidden = self.inner(torch.flatten(pooled, 1))
        return self.classifier(hidden)
class TorchTeacherModel(torch.nn.Module):
    """Single ``Linear(1, 1)`` model stored under the attribute ``teacher_fc``."""

    def __init__(self):
        super().__init__()
        self.teacher_fc = torch.nn.Linear(1, 1)

    def forward(self, inputs):
        """Return ``teacher_fc`` applied to ``inputs``."""
        return self.teacher_fc(inputs)
class TorchStudentModel(torch.nn.Module):
    """Single ``Linear(1, 1)`` model stored under the attribute ``student_fc``."""

    def __init__(self):
        super().__init__()
        self.student_fc = torch.nn.Linear(1, 1)

    def forward(self, inputs):
        """Return ``student_fc`` applied to ``inputs``."""
        return self.student_fc(inputs)
class GroupLinearTorchModel(torch.nn.Module):
    """Two bias-free 256x256 linear layers with a ReLU in between.

    ``device``, ``config`` and ``dtype`` are stored as plain attributes; the
    constructor does not cast the layers to ``dtype``.
    """

    def __init__(self):
        super().__init__()
        # Plain metadata attributes.
        self.device = torch.device('cpu')
        self.config = None
        self.dtype = torch.float16
        # Layers, created in the order they are applied.
        self.l1 = torch.nn.Linear(256, 256, bias=False)
        self.l2 = torch.nn.Linear(256, 256, bias=False)

    def forward(self, x):
        """Return ``l2(relu(l1(x)))``."""
        hidden = self.l1(x)
        activated = torch.nn.functional.relu(hidden)
        return self.l2(activated)
class TwoLinearTorchModel(torch.nn.Module):
    """Two bias-free 8x8 linear layers with a ReLU in between.

    ``device``, ``config`` and ``dtype`` are stored as plain attributes; the
    constructor does not cast the layers to ``dtype``.
    """

    def __init__(self):
        super().__init__()
        # Plain metadata attributes.
        self.device = torch.device('cpu')
        self.config = None
        self.dtype = torch.float16
        # Layers, created in the order they are applied.
        self.l1 = torch.nn.Linear(8, 8, bias=False)
        self.l2 = torch.nn.Linear(8, 8, bias=False)

    def forward(self, x):
        """Return ``l2(relu(l1(x)))``."""
        hidden = self.l1(x)
        activated = torch.nn.functional.relu(hidden)
        return self.l2(activated)
class ThreeLinearTorchModel_for_Sparse(torch.nn.Module):
    """Three bias-free 256x256 linear layers applied back to back, with no activation.

    ``device``, ``config`` and ``dtype`` are stored as plain attributes; the
    constructor does not cast the layers to ``dtype``.
    """

    def __init__(self):
        super().__init__()
        # Plain metadata attributes.
        self.device = torch.device('cpu')
        self.config = None
        self.dtype = torch.float16
        # Layers, created in the order they are applied.
        self.l1 = torch.nn.Linear(256, 256, bias=False)
        self.l2 = torch.nn.Linear(256, 256, bias=False)
        self.l3 = torch.nn.Linear(256, 256, bias=False)

    def forward(self, x):
        """Return ``l3(l2(l1(x)))``."""
        first = self.l1(x)
        second = self.l2(first)
        return self.l3(second)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: Size of the gain parameter ``g``, which has shape ``(1, dim)``.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim))

    def forward(self, x):
        """Return ``g * x / sqrt(mean(x^2, last dim) + eps)``."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.g * (x * inv_rms)
class AttentionTorchModel(nn.Module):
    """Attention-style test model: RMSNorm, q/k/v projections, output projection, RMSNorm.

    ``device``, ``config`` and ``dtype`` are stored as plain attributes.
    ``forward`` casts its input to float32 regardless of ``dtype``.

    Args:
        embed_dim: Feature size of every linear layer and norm.
        num_heads: Number of heads the projections are split into.
    """

    def __init__(self, embed_dim=32, num_heads=8):
        super().__init__()
        # Plain metadata attributes.
        self.device = torch.device('cpu')
        self.config = None
        self.dtype = torch.float16
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Sub-modules, created in the order they are applied.
        self.input_norm = RMSNorm(embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.o = nn.Linear(embed_dim, embed_dim)
        self.post_norm = RMSNorm(embed_dim)
        # Score scaling uses the full embedding size.
        self.scale = embed_dim ** -0.5

    def forward(self, hidden_states, past_key_value=None, mask=None):
        """Compute the attention output and return it with the key/value pair.

        Args:
            hidden_states: Input tensor; cast to float32 before use.
            past_key_value: Optional ``(key, value)`` pair that is concatenated
                in front of the new key/value along dim 0.
            mask: Optional tensor; score positions where ``mask == 0`` are set
                to negative infinity before the softmax.

        Returns:
            ``(output, (key, value))`` where key/value include any past entries.
        """
        head_shape = (-1, self.num_heads, self.embed_dim // self.num_heads)

        normed = self.input_norm(hidden_states.to(torch.float32))
        query = self.q_proj(normed)
        key = self.k_proj(normed)
        value = self.v_proj(normed)

        # Split the feature axis into (num_heads, head_dim); all leading axes
        # are merged into one.
        query = query.view(*head_shape)
        key = key.view(*head_shape)
        value = value.view(*head_shape)

        if past_key_value is not None:
            cached_key, cached_value = past_key_value[0], past_key_value[1]
            key = torch.cat([cached_key, key], dim=0)
            value = torch.cat([cached_value, value], dim=0)

        # With the layout above, the matmul contracts over head_dim, so the
        # score matrix has shape (..., num_heads, num_heads).
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, value)
        # num_heads * embed_dim // num_heads is exactly embed_dim for integers.
        context = context.view(-1, self.embed_dim)

        output = self.post_norm(self.o(context))
        return output, (key, value)
class SophonRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned ``weight`` of shape ``(dim,)``."""

    def __init__(self, dim, eps=1e-6):
        """
        SophonRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` by its root mean square and scale by ``weight``.

        The mean square is computed in float32. When ``weight`` is float16 or
        bfloat16, the normalised tensor is cast to that dtype before scaling.
        """
        mean_square = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.eps)
        low_precision = (torch.float16, torch.bfloat16)
        if self.weight.dtype in low_precision:
            normed = normed.to(self.weight.dtype)
        return self.weight * normed
class SophonTorchAttention(nn.Module):
    """Attention block with q/k/v/o linear projections of size ``embed_dim``.

    Args:
        embed_dim: Feature size of every projection.
        num_heads: Number of heads the projections are split into.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim)
        # Score scaling uses the full embedding size.
        self.scale = self.embed_dim ** -0.5

    def forward(self, hidden_states, past_key_value=None, mask=None, **kargs):
        """Compute the attention output and return it with the key/value pair.

        Args:
            hidden_states: Input tensor fed to the q/k/v projections.
            past_key_value: Optional ``(key, value)`` pair that is concatenated
                in front of the new key/value along dim 0.
            mask: Optional tensor; score positions where ``mask == 0`` are set
                to negative infinity before the softmax.
            **kargs: Accepted and ignored.

        Returns:
            ``(output, (key, value))`` where key/value include any past entries.
        """
        head_shape = (-1, self.num_heads, self.embed_dim // self.num_heads)

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Split the feature axis into (num_heads, head_dim); all leading axes
        # are merged into one.
        query = query.view(*head_shape)
        key = key.view(*head_shape)
        value = value.view(*head_shape)

        if past_key_value is not None:
            cached_key, cached_value = past_key_value[0], past_key_value[1]
            key = torch.cat([cached_key, key], dim=0)
            value = torch.cat([cached_value, value], dim=0)
        kv_cache = (key, value)

        # With the layout above, the matmul contracts over head_dim, so the
        # score matrix has shape (..., num_heads, num_heads).
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, value)
        # num_heads * embed_dim // num_heads is exactly embed_dim for integers.
        context = context.view(-1, self.embed_dim)
        return self.o_proj(context), kv_cache
class SophonTorchMlp(nn.Module):
    """Gated MLP with two gate projections, an up projection and a down projection.

    All linear layers are bias-free and map ``embed_dim`` to ``embed_dim``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.gate_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.gate_proj2 = nn.Linear(embed_dim, embed_dim, bias=False)
        self.act1 = nn.ReLU(inplace=True)
        self.down_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.act2 = nn.ReLU(inplace=True)

    def forward(self, hidden_states):
        """Return ``down_proj((act2(gate_proj(h)) + act1(gate_proj2(h))) * up_proj(h))``."""
        # Evaluation order matches the single-expression form: first gate,
        # second gate, up projection, then down projection.
        first_gate = self.act2(self.gate_proj(hidden_states))
        second_gate = self.act1(self.gate_proj2(hidden_states))
        gated = (first_gate + second_gate) * self.up_proj(hidden_states)
        return self.down_proj(gated)
class SophonTorchDecoder(nn.Module):
    """Decoder layer: norm -> attention -> residual add, then norm -> MLP -> residual add.

    Args:
        embed_dim: Feature size passed to the norms, attention and MLP.
        num_heads: Number of heads passed to the attention block.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.input_norm = SophonRMSNorm(self.embed_dim)
        self.attn = SophonTorchAttention(self.embed_dim, self.num_heads)
        self.mlp = SophonTorchMlp(self.embed_dim)
        self.post_norm = SophonRMSNorm(self.embed_dim)

    def forward(self, hidden_states, attention_mask, rotary_pos_emb_list, use_cache=False):
        """Apply the attention and MLP sub-blocks with residual connections.

        ``attention_mask``, ``rotary_pos_emb_list`` and ``use_cache`` are
        accepted but not used by this implementation.

        Returns:
            ``(hidden_states, past_key_value)`` where the second item is the
            key/value pair returned by the attention block.
        """
        # Attention sub-block.
        attn_out, past_key_value = self.attn(self.input_norm(hidden_states))
        after_attn = hidden_states + attn_out

        # MLP sub-block.
        mlp_out = self.mlp(self.post_norm(after_attn))
        return after_attn + mlp_out, past_key_value
class AttentionTorchSophonModel(nn.Module):
    """Model with one ``SophonTorchDecoder`` layer followed by a final ``SophonRMSNorm``.

    ``device``, ``config`` and ``dtype`` are stored as plain attributes.

    Args:
        embed_dim: Feature size passed to the decoder layer and final norm.
        num_heads: Number of heads passed to the decoder layer.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        # Plain metadata attributes.
        self.device = torch.device('cpu')
        self.config = None
        self.dtype = torch.float16
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # A single decoder layer, kept in a ModuleList.
        self.layers = nn.ModuleList(
            [SophonTorchDecoder(self.embed_dim, self.num_heads)]
        )
        self.norm = SophonRMSNorm(self.embed_dim)

    def forward(self, hidden_states, use_cache=False):
        """Run every decoder layer, then the final norm.

        Each layer is called with the placeholder value ``1`` for both
        ``attention_mask`` and ``rotary_pos_emb_list``.

        Returns:
            ``(hidden_states, past_key_value)`` where the second item comes
            from the last decoder layer.
        """
        for decoder_layer in self.layers:
            attention_mask = 1
            rotary_pos_emb_list = 1
            hidden_states, past_key_value = decoder_layer(
                hidden_states,
                attention_mask,
                rotary_pos_emb_list,
                use_cache=use_cache,
            )
        return self.norm(hidden_states), past_key_value
class ExpertFFN(nn.Module):
    """Expert feed-forward block computing ``w2(act_fn(w1(x)) * w3(x))``.

    ``act_fn`` is an ``RMSNorm`` instance; ``w1``, ``w2`` and ``w3`` are
    bias-free linear layers mapping ``embed_dim`` to ``embed_dim``.
    """

    def __init__(self, embed_dim=32):
        super().__init__()
        self.act_fn = RMSNorm(embed_dim)
        self.w1 = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w2 = nn.Linear(embed_dim, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, hidden_states):
        """Return the gated projection of ``hidden_states``."""
        gate = self.act_fn(self.w1(hidden_states))
        gated = gate * self.w3(hidden_states)
        return self.w2(gated)
class MOEModel(nn.Module):
    """
    MOE-like model holding two ``ExpertFFN`` experts.

    Only ``expert1`` is invoked in ``forward``; ``expert2`` is constructed and
    registered but never called there. ``dtype`` and ``device`` are stored as
    plain attributes.
    """

    def __init__(self, embed_dim=32):
        super().__init__()
        self.expert1 = ExpertFFN(embed_dim)
        self.expert2 = ExpertFFN(embed_dim)
        self.dtype = torch.float16
        self.device = torch.device('cpu')

    def forward(self, x):
        """Return the output of ``expert1`` applied to ``x``."""
        return self.expert1(x)