from dataset import Batch
from hybrid_torchrec import HashEmbeddingBagCollection, HashEmbeddingBagConfig
import torch
from torchrec import PoolingType
class TestModel(torch.nn.Module):
    """Hash embedding-bag collection followed by a single square linear layer.

    One ``HashEmbeddingBagConfig`` (SUM pooling) is built per table and all of
    them are handed to a ``HashEmbeddingBagCollection`` created on the
    ``"npu"`` device. The output of the collection is projected by a
    ``torch.nn.Linear`` whose input and output widths are both
    ``sum(len(features) * dim)`` over all tables.
    """

    def __init__(self, table_names, feat_names, embed_dims, num_embeds):
        """Build the embedding collection and the linear projection.

        The four arguments are parallel sequences: position ``i`` of each one
        describes table ``i``.

        Args:
            table_names: Name of each embedding table.
            feat_names: For each table, the sized collection of feature names
                served by that table (passed through as ``feature_names``).
            embed_dims: Embedding dimension of each table.
            num_embeds: Number of embeddings (rows) of each table.

        Raises:
            ValueError: If the four arguments do not all have the same length.
        """
        super().__init__()

        # Materialize once so one-shot iterables keep working and so the
        # lengths can be compared before anything is built.
        table_names = list(table_names)
        feat_names = list(feat_names)
        embed_dims = list(embed_dims)
        num_embeds = list(num_embeds)

        # zip() stops at the shortest input, which would silently drop tables
        # from a mismatched configuration; fail loudly instead.
        lengths = {
            "table_names": len(table_names),
            "feat_names": len(feat_names),
            "embed_dims": len(embed_dims),
            "num_embeds": len(num_embeds),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"per-table arguments must have equal lengths, got {lengths}"
            )

        table_configs = [
            HashEmbeddingBagConfig(
                name=table_name,
                embedding_dim=dim,
                num_embeddings=num_embed,
                feature_names=feat_name,
                pooling=PoolingType.SUM,
            )
            for table_name, feat_name, dim, num_embed in zip(
                table_names, feat_names, embed_dims, num_embeds
            )
        ]
        self.ebc = HashEmbeddingBagCollection(device="npu", tables=table_configs)

        # Each feature of a table contributes `dim` columns to the linear
        # layer's width.
        self.input_dim = sum(
            len(features) * dim for features, dim in zip(feat_names, embed_dims)
        )
        self.linear_net = torch.nn.Linear(self.input_dim, self.input_dim)

    def forward(self, batch: "Batch"):
        """Run the embedding lookup and the linear projection.

        Args:
            batch: Input batch; only ``batch.sparse_features`` is read.

        Returns:
            A ``(loss, result)`` pair, where ``result`` is the output of the
            linear layer and ``loss`` is
            ``result.mean() + result.sum() + result.max() + result.min()``.
        """
        pooled = self.ebc(batch.sparse_features)
        # NOTE(review): presumably `pooled` is a torchrec KeyedTensor and
        # `.values()` is its dense 2-D tensor -- confirm against hybrid_torchrec.
        embeddings: torch.Tensor = pooled.values()
        result = self.linear_net(embeddings)
        loss = result.mean() + result.sum() + result.max() + result.min()
        return loss, result