from typing import Any
import fbgemm_gpu
import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import Tensor
INTERN_MODULE = "fbgemm_gpu.permute_pooled_embedding_modules"
FIXED_EXTERN_API = {
"PermutePooledEmbeddings": {
"__init__": ["self", "embs_dims", "permute", "device"],
"__call__": ["self", "pooled_embs"],
},
}
FWD_COMPAT_MSG = (
"WARNING: If this test is failing, you are probably trying "
"to make changes to a module that has been marked external to PyPer torch packages. "
"This can break forward compatibility of torch packages on training_platform "
"(see https://fb.workplace.com/groups/pyper/permalink/808155810065803/). "
"You need to split up your changes as follows:\n"
"\t1. Edit your diff so it only contains the changes as optional, and not any usage of the"
" new optional changes.\n"
"\t2. Edit FIXED_EXTERN_API in this test so your diff passes the test.\n"
"\t3. Land your diff and wait for the diff to be picked up by the production version of"
" fbpkg training_platform.\n"
"\t4. Once step 3. is complete, you can push the rest of your changes that use the new"
" changes."
)
class PermutePooledEmbeddingsFwdOnly(PermutePooledEmbeddings):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __call__(self, pooled_embs: torch.Tensor) -> torch.Tensor:
result = torch.ops.fbgemm.permute_pooled_embs(
pooled_embs,
self._offset_dim_list.to(device=pooled_embs.device),
self._permute.to(device=pooled_embs.device),
self._inv_offset_dim_list.to(device=pooled_embs.device),
self._inv_permute.to(device=pooled_embs.device),
)
return result
class Net(torch.nn.Module):
def __init__(self, fwd_only: bool = False) -> None:
super(Net, self).__init__()
self.fc1 = torch.nn.Linear(1, 10, bias=False)
op_cls = PermutePooledEmbeddingsFwdOnly if fwd_only else PermutePooledEmbeddings
self.permute_pooled_embeddings: PermutePooledEmbeddings = op_cls(
[2, 3, 1, 4],
[3, 0, 2, 1],
)
self.fc2 = torch.nn.Linear(10, 1, bias=False)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.permute_pooled_embeddings(x)
x = self.fc2(x)
return x