模型接口
get_embedding_table
创建稀疏表。
def get_embedding_table(
name: str,
dimension: int,
device_vocabulary_size: int,
initializer: Union[InitializerV1, InitializerV2] = tf.compat.v1.random_normal_initializer(),
key_dtype: tf.DType = tf.int64,
value_dtype: tf.DType = tf.float32,
distribution_strategy: str = EmbDistributionStrategy.MP.value,
min_used_times: Optional[int] = None,
max_cold_secs: Optional[int] = None,
):
| 参数名 | 类型 | 可选/必选 | 说明 |
|---|---|---|---|
| name | str | 必选 | 稀疏表名,只能包含[0-9A-Za-z_.],表名长度范围:[1, 128]。 |
| dimension | int | 必选 | 稀疏表的embedding维度,取值范围:[1, 512]。 |
| device_vocabulary_size | int | 必选 | Device侧稀疏表容量,取值范围:[1, 10**9]。请保证内存和磁盘空间足够,根据服务器的实际配置进行设置。 |
| initializer | Tensorflow的初始化器类型 | 可选 | 稀疏表初始值生成器,默认值为随机正态分布初始化器。 |
| key_dtype | tf.int64 | 可选 | 稀疏特征key数据类型,默认值为tf.int64,可选类型仅限于tf.int64。 |
| value_dtype | tf.float32 | 可选 | 稀疏特征value数据类型,默认值为tf.float32,可选类型仅限于tf.float32。 |
| distribution_strategy | str | 可选 | 稀疏表分布式并行模式,默认值为"MP"(模型并行),当前仅支持"MP"。 |
| min_used_times | int / None | 可选 | 特征准入功能,ID历史出现次数必须超过该值才能生效,否则使用该ID查表返回默认的Embedding值,取值范围[0, 2^31-1]。 |
| max_cold_secs | int / None | 可选 | 特征淘汰功能,该值用于表示ID和Embedding自最后一次被访问后,最久能够容忍的未访问时长。超过这个阈值的ID将会在保存时,从存储Embedding的哈希表中删除,取值范围[0, 2^64-1]。 |
说明
- 当前仅支持非扩容模式,即
device_vocabulary_size需大于0。 - 开启特征淘汰功能后暂不支持保存与加载。
- 成功:返回稀疏表实例
- 失败:抛出异常
import mxrec
import tensorflow as tf
# MxRec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
initializer=tf.truncated_normal_initializer(),
key_dtype=tf.int64,
value_dtype=tf.float32,
)
embedding_lookup
稀疏表特征查询接口。
暂不支持tf.SparseTensor数据类型,若为tf.SparseTensor需转成tf.Tensor。示例代码如下:
sparse_ids = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
dense_ids = tf.sparse.to_dense(sparse_ids, default_value=0)
def embedding_lookup(
emb_table: Union[StaticEmbTable],
ids: tf.Tensor,
):
| 参数名 | 类型 | 可选/必选 | 说明 |
|---|---|---|---|
| emb_table | 稀疏表实例 | 必选 | 稀疏表实例,通过get_embedding_table接口得到。 |
| ids | tf.Tensor | 必选 | 待查询的关键字(key)。 |
- 成功:返回查询到Tensor结果
- 失败:抛出异常
import mxrec
import tensorflow as tf
# Mxrec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
)
# Embedding lookup.
ids = tf.convert_to_tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.int64)
embedding = mxrec.embedding_lookup(table, ids)
get_sparse_embedding
获取已创建稀疏表的可训练参数,用于优化器计算/更新梯度。
def get_sparse_embedding():
- 成功:返回已创建稀疏表的可训练参数列表
- 失败:抛出异常
import mxrec
import tensorflow as tf
# MxRec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
)
init_hashtable_op = mxrec.get_init_hashtable_op()
# The model's loss.
loss = ...
sparse_optimizer = mxrec.AdamWOptimizer(learning_rate=0.01)
sparse_embeddings = mxrec.get_sparse_embedding()
sparse_grads = tf.gradients(loss, sparse_embeddings)
train_ops = sparse_optimizer.apply_gradients(zip(sparse_grads, sparse_embeddings))
get_init_hashtable_op
获取稀疏表的初始化算子列表,计算图执行之前需先执行此算子列表。
def get_init_hashtable_op():
- 成功:稀疏表的初始化算子列表
- 失败:抛出异常
import mxrec
import tensorflow as tf
# MxRec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
initializer=tf.truncated_normal_initializer(),
key_dtype=tf.int64,
value_dtype=tf.float32,
)
init_hashtable_op = mxrec.get_init_hashtable_op()
with tf.compat.v1.Session as sess:
sess.run(init_hashtable_op)
sess.run(tf.compat.v1.global_variables_initializer())
get_existing_tables
获取当前所有已创建的稀疏表对象。
def get_existing_tables():
- 成功:当前所有已创建的稀疏表对象
- 失败:抛出异常
import mxrec
import tensorflow as tf
# MxRec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
initializer=tf.truncated_normal_initializer(),
key_dtype=tf.int64,
value_dtype=tf.float32,
)
# Get all created embedding table instance.
emb_tables = mxrec.get_existing_table()
EmbeddingTableSaver
EmbeddingTableSaver类被设计用于在基于Tensorflow框架的模型中管理稀疏表(NPU设备实现)的保存和恢复。它提供了一个方便的接口来保存和恢复稀疏表的数据。
__init__
初始化EmbeddingTableSaver对象。
def __init__(self, emb_tables: List[BaseEmbeddingTable]):
| 参数名 | 类型 | 可选/必选 | 说明 |
|---|---|---|---|
| emb_tables | List[BaseEmbeddingTable] | 必选 | 待执行保存恢复操作的稀疏表对象,通过get_existing_tables接口得到。 |
说明
BaseEmbeddingTable表示稀疏表实例,可通过get_embedding_table接口获取。
- 成功:None
- 失败:抛出异常
import mxrec
import tensorflow as tf
# MxRec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
initializer=tf.truncated_normal_initializer(),
key_dtype=tf.int64,
value_dtype=tf.float32,
)
# Get all created embedding table instance.
emb_tables = mxrec.get_existing_table()
emb_table_saver = mxrec.EmbeddingTableSaver(emb_tables)
save
执行稀疏表(NPU实现)的保存操作。
def save(self, sess: tf.compat.v1.Session, save_path: str, global_step: int):
| 参数名 | 类型 | 可选/必选 | 说明 |
|---|---|---|---|
| sess | tf.Session | 必选 | Tensorflow当前执行中的session。 |
| save_path | str | 必选 | 保存路径 |
| global_step | int | 必选 | 保存步数,取值范围[0, 2^32-1] |
- 成功:None
- 失败:抛出异常
import mxrec
import tensorflow as tf
# MxRec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
initializer=tf.truncated_normal_initializer(),
key_dtype=tf.int64,
value_dtype=tf.float32,
)
# Get all created embedding table instance.
emb_tables = mxrec.get_existing_table()
emb_table_saver = mxrec.EmbeddingTableSaver(emb_tables)
with tf.compat.v1.Session() as sess:
# run your session
emb_table_saver.save(sess, save_path, global_step)
load
执行稀疏表(NPU实现)的加载操作。
def load(self, sess: tf.compat.v1.Session, save_path: str, global_step: int):
| 参数名 | 类型 | 可选/必选 | 说明 |
|---|---|---|---|
| sess | tf.Session | 必选 | Tensorflow当前执行中的session。 |
| save_path | str | 必选 | 加载路径 |
| global_step | int | 必选 | 加载步数,取值范围[0, 2^32-1] |
- 成功:None
- 失败:抛出异常
import mxrec
import tensorflow as tf
# MxRec init.
mxrec.init("toml_path")
# Create an embedding table.
table = mxrec.get_embedding_table(
name="example_name",
dimension=8,
device_vocabulary_size=10000,
initializer=tf.truncated_normal_initializer(),
key_dtype=tf.int64,
value_dtype=tf.float32,
)
# Get all created embedding table instance.
emb_tables = mxrec.get_existing_table()
emb_table_saver = mxrec.EmbeddingTableSaver(emb_tables)
with tf.compat.v1.Session() as sess:
emb_table_saver.load(sess, save_path, global_step)
# run your session