* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 and
* only version 2 as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*/
#include "pbl/pbl_uda.h"
#include "devdrv_pasid_rbtree.h"
#include "devdrv_mem_alloc.h"
#include "devdrv_ctrl.h"
#include "devdrv_util.h"
#ifdef CFG_FEATURE_DMA_COPY_SVA
/*
 * Descriptor size used for both the s_desc_size and c_desc_size of the PASID
 * non-trans message channel; the message edge check also derives the maximum
 * accepted payload length from it.
 */
#define DEVDRV_NON_TRANS_MSG_SVA_DESC_SIZE 0x400
/*
 * PASID bookkeeping, one control entry per device index. Each entry carries
 * the rb-tree root, the spinlock guarding it and the message channel handle
 * (see devdrv_pasid_rbtree_init / devdrv_pasid_non_trans_init).
 */
struct devdrv_dma_pasid_rbtree_ctrl g_pasid_rbtree[MAX_DEV_CNT];
/*
 * Link dma_pasid_node into the rb-tree rooted at root, ordered by hash_va.
 *
 * lock is taken with ka_task_spin_lock_bh() for the whole walk and insertion
 * and must not be NULL.
 *
 * Returns 0 on success, or -EEXIST when a node with the same hash_va is
 * already present (the tree is left untouched in that case).
 */
STATIC int devdrv_dma_pasid_node_insert(ka_task_spinlock_t *lock, ka_rb_root_t *root,
    struct devdrv_dma_pasid_rbtree_node *dma_pasid_node)
{
    ka_rb_node_t *parent = NULL;
    ka_rb_node_t **link = NULL;
    u64 key;
    int ret = 0;

    ka_task_spin_lock_bh(lock);

    link = ka_base_get_rb_root_node_addr(root);
    key = dma_pasid_node->hash_va;

    /* Descend until an empty child slot is found or the key collides. */
    while (*link != NULL) {
        struct devdrv_dma_pasid_rbtree_node *cur =
            ka_base_rb_entry(*link, struct devdrv_dma_pasid_rbtree_node, node);

        parent = *link;
        if (key < cur->hash_va) {
            link = &parent->rb_left;
        } else if (key > cur->hash_va) {
            link = &parent->rb_right;
        } else {
            ret = -EEXIST;
            break;
        }
    }

    if (ret == 0) {
        ka_base_rb_link_node(&dma_pasid_node->node, parent, link);
        ka_base_rb_insert_color(&dma_pasid_node->node, root);
    }

    ka_task_spin_unlock_bh(lock);
    return ret;
}
/*
 * Look up the node whose hash_va equals the given value.
 *
 * lock may be NULL when the caller already holds the tree lock; otherwise it
 * is taken for the duration of the walk and released before returning, so the
 * returned pointer is no longer covered by the lock in that case.
 *
 * Returns the matching node, or NULL when no node carries hash_va.
 */
STATIC struct devdrv_dma_pasid_rbtree_node *devdrv_dma_pasid_node_search(ka_task_spinlock_t *lock, ka_rb_root_t *root,
    u64 hash_va)
{
    struct devdrv_dma_pasid_rbtree_node *found = NULL;
    ka_rb_node_t *cur = NULL;

    if (lock != NULL) {
        ka_task_spin_lock_bh(lock);
    }

    cur = ka_base_get_rb_root_node(root);
    while (cur != NULL) {
        struct devdrv_dma_pasid_rbtree_node *entry =
            ka_base_rb_entry(cur, struct devdrv_dma_pasid_rbtree_node, node);

        if (hash_va < entry->hash_va) {
            cur = cur->rb_left;
        } else if (hash_va > entry->hash_va) {
            cur = cur->rb_right;
        } else {
            found = entry;
            break;
        }
    }

    if (lock != NULL) {
        ka_task_spin_unlock_bh(lock);
    }
    return found;
}
/*
 * Unlink dma_pasid_node from the rb-tree rooted at root. The node memory is
 * not freed here.
 *
 * lock may be NULL when the caller already holds the tree lock; otherwise the
 * erase is performed with lock held.
 */
STATIC void devdrv_dma_pasid_node_erase(ka_task_spinlock_t *lock, ka_rb_root_t *root,
    struct devdrv_dma_pasid_rbtree_node *dma_pasid_node)
{
    if (lock == NULL) {
        /* Caller owns the locking. */
        ka_base_rb_erase(&dma_pasid_node->node, root);
        return;
    }

    ka_task_spin_lock_bh(lock);
    ka_base_rb_erase(&dma_pasid_node->node, root);
    ka_task_spin_unlock_bh(lock);
}
/*
 * Allocate a tree node for pasid and insert it into the tree of device dev_id.
 *
 * Returns 0 on success, -EINVAL for an out-of-range dev_id, -ENOMEM when the
 * node cannot be allocated, or the insert error (-EEXIST for a duplicate), in
 * which case the freshly allocated node is freed again.
 */
STATIC int devdrv_pasid_rbtree_add(u32 dev_id, u64 pasid)
{
    struct devdrv_dma_pasid_rbtree_ctrl *tree = NULL;
    struct devdrv_dma_pasid_rbtree_node *entry = NULL;
    int ret;

    if (dev_id >= MAX_DEV_CNT) {
        devdrv_err("Invalid dev_id.(dev_id=%u)\n", dev_id);
        return -EINVAL;
    }

    entry = devdrv_kzalloc(sizeof(*entry), KA_GFP_KERNEL | __KA_GFP_ACCOUNT);
    if (entry == NULL) {
        devdrv_err("Alloc mem for rb_tree failed.(dev_id=%u).\n", dev_id);
        return -ENOMEM;
    }
    entry->dev_id = dev_id;
    entry->hash_va = pasid;

    tree = &g_pasid_rbtree[dev_id];
    ret = devdrv_dma_pasid_node_insert(&tree->rb_lock, &tree->rbtree, entry);
    if (ret == 0) {
        return 0;
    }

    /* The node never made it into the tree, so it is still ours to free. */
    devdrv_err("Insert rbtree failed.(dev_id=%u; pasid=%llu)\n", dev_id, pasid);
    devdrv_kfree(entry);
    return ret;
}
/*
 * Remove the node for pasid from the tree of device dev_id and free it.
 *
 * Lookup and erase run under a single hold of the tree lock (the helpers are
 * called with a NULL lock), so the node cannot be removed by another path
 * between the two steps.
 *
 * Returns 0 on success, -EINVAL for an out-of-range dev_id, or -ENOENT when
 * pasid is not in the tree.
 */
STATIC int devdrv_pasid_rbtree_del(u32 dev_id, u64 pasid)
{
    struct devdrv_dma_pasid_rbtree_ctrl *ctrl = NULL;
    struct devdrv_dma_pasid_rbtree_node *victim = NULL;

    if (dev_id >= MAX_DEV_CNT) {
        devdrv_err("Invalid dev_id.(dev_id=%u)\n", dev_id);
        return -EINVAL;
    }
    ctrl = &g_pasid_rbtree[dev_id];

    ka_task_spin_lock_bh(&ctrl->rb_lock);
    victim = devdrv_dma_pasid_node_search(NULL, &ctrl->rbtree, pasid);
    if (victim != NULL) {
        devdrv_dma_pasid_node_erase(NULL, &ctrl->rbtree, victim);
    }
    ka_task_spin_unlock_bh(&ctrl->rb_lock);

    if (victim == NULL) {
        devdrv_err("Pasid is not in rbtree.(dev_id=%u; pasid=%llu)\n", dev_id, pasid);
        return -ENOENT;
    }

    devdrv_kfree(victim);
    return 0;
}
/*
 * Validate an incoming PASID message before it is processed.
 *
 * Checks, in order: devid range and non-NULL data / p_real_out_len; that
 * in_data_len holds at least one struct devdrv_pasid_msg and does not exceed
 * the descriptor size minus the message head; and that the op_code and dev_id
 * carried in the message are in range. out_data_len is accepted but not
 * examined here.
 *
 * Returns 0 when every check passes, -EINVAL otherwise.
 */
STATIC int devdrv_pasid_msg_edge_check(u32 devid, void *data, u32 in_data_len, u32 out_data_len, u32 *p_real_out_len)
{
    u32 max_data_len = DEVDRV_NON_TRANS_MSG_SVA_DESC_SIZE - DEVDRV_NON_TRANS_MSG_HEAD_LEN;
    struct devdrv_pasid_msg *msg = (struct devdrv_pasid_msg *)data;
    bool len_ok;

    if ((devid >= MAX_DEV_CNT) || (data == NULL) || (p_real_out_len == NULL)) {
        devdrv_err("Input parameter is error.(devid=%u)\n", devid);
        return -EINVAL;
    }

    len_ok = (in_data_len >= sizeof(struct devdrv_pasid_msg)) && (in_data_len <= max_data_len);
    if (!len_ok) {
        devdrv_err("Input parameter in_data_len is error.(devid=%u; in_data_len=0x%x)\n", devid, in_data_len);
        return -EINVAL;
    }

    /* The payload is only dereferenced once pointer and length are known good. */
    if ((msg->op_code >= (u32)DEVDRV_PASID_MAX) || (msg->dev_id >= MAX_DEV_CNT)) {
        devdrv_err("Pasid message channel recv message op_code or devid error.(dev_id=%u; msg_type=%u; msg_devid=%u;)"
            ")\n", devid, msg->op_code, msg->dev_id);
        return -EINVAL;
    }

    return 0;
}
/*
 * Receive handler of the PASID non-trans message channel.
 *
 * Validates the message, converts the device id carried in it with
 * uda_udevid_to_add_id() and then adds the PASID to (DEVDRV_PASID_ADD) or
 * removes it from (any other accepted op_code) that device's tree.
 *
 * The outcome of the conversion and of the tree operation is reported back to
 * the sender through recv_msg->error_code, and *real_out_len is set to
 * sizeof(struct devdrv_pasid_msg).
 *
 * Returns -EINVAL when the message fails validation, 0 otherwise (including
 * when the requested operation itself failed; see error_code).
 */
STATIC int devdrv_pasid_non_trans_msg_recv(void *msg_chan, void *data, u32 in_data_len,
    u32 out_data_len, u32 *real_out_len)
{
    u32 devid = (u32)devdrv_get_msg_chan_devid_inner(msg_chan);
    struct devdrv_pasid_msg *recv_msg = NULL;
    u32 index_id;
    int ret;

    if (devdrv_pasid_msg_edge_check(devid, data, in_data_len, out_data_len, real_out_len) != 0) {
        devdrv_err("Input parameters check failed.(devid=%u)\n", devid);
        return -EINVAL;
    }

    *real_out_len = sizeof(struct devdrv_pasid_msg);
    recv_msg = (struct devdrv_pasid_msg *)data;

    /*
     * Do not touch any tree unless the id conversion succeeded: on failure
     * index_id is presumably left unwritten, and using it would index
     * g_pasid_rbtree with an indeterminate value.
     */
    ret = uda_udevid_to_add_id(recv_msg->dev_id, &index_id);
    if (ret != 0) {
        devdrv_err("Convert udevid to add_id failed.(devid=%u; msg_devid=%u; ret=%d)\n",
            devid, recv_msg->dev_id, ret);
        recv_msg->error_code = ret;
        return 0;
    }

    if (recv_msg->op_code == DEVDRV_PASID_ADD) {
        ret = devdrv_pasid_rbtree_add(index_id, recv_msg->hash_va);
        if (ret != 0) {
            devdrv_err("Add node to rbtree failed.(index_id=%u; pasid=%llu)\n", index_id, recv_msg->hash_va);
        }
    } else {
        ret = devdrv_pasid_rbtree_del(index_id, recv_msg->hash_va);
        if (ret != 0) {
            devdrv_err("Del node from rbtree failed.(index_id=%u; pasid=%llu)\n", index_id, recv_msg->hash_va);
        }
    }

    recv_msg->error_code = ret;
    return 0;
}
/*
 * Drop every node from the PASID tree of device dev_id and free it.
 *
 * Nodes are unlinked one at a time with the tree lock held; the lock is
 * released around each devdrv_kfree() call and re-taken before the next
 * lookup. dev_id is used as an index without a range check.
 */
STATIC void devdrv_pasid_rbtree_clear(u32 dev_id)
{
    struct devdrv_dma_pasid_rbtree_ctrl *ctrl = &g_pasid_rbtree[dev_id];
    struct devdrv_dma_pasid_rbtree_node *entry = NULL;
    ka_rb_node_t *first = NULL;

    ka_task_spin_lock_bh(&ctrl->rb_lock);
    for (;;) {
        first = ka_base_rb_first(&ctrl->rbtree);
        if (first == NULL) {
            break;
        }

        entry = ka_base_rb_entry(first, struct devdrv_dma_pasid_rbtree_node, node);
        devdrv_dma_pasid_node_erase(NULL, &ctrl->rbtree, entry);

        /* Free outside the lock, then re-acquire it for the next round. */
        ka_task_spin_unlock_bh(&ctrl->rb_lock);
        devdrv_kfree(entry);
        ka_task_spin_lock_bh(&ctrl->rb_lock);
    }
    ka_task_spin_unlock_bh(&ctrl->rb_lock);
}
/*
 * Description of the PASID non-trans message channel; it is handed to
 * devdrv_pcimsg_alloc_non_trans_queue_inner() in devdrv_pasid_non_trans_init().
 * Both descriptor sizes use DEVDRV_NON_TRANS_MSG_SVA_DESC_SIZE and received
 * messages are dispatched to devdrv_pasid_non_trans_msg_recv().
 */
STATIC struct devdrv_non_trans_msg_chan_info devdrv_pasid_non_trans_msg_chan_info = {
.msg_type = devdrv_msg_client_pasid,
.flag = 0,
.level = DEVDRV_MSG_CHAN_LEVEL_HIGH,
.s_desc_size = DEVDRV_NON_TRANS_MSG_SVA_DESC_SIZE,
.c_desc_size = DEVDRV_NON_TRANS_MSG_SVA_DESC_SIZE,
.rx_msg_process = devdrv_pasid_non_trans_msg_recv,
};
#endif
/*
 * Tell whether pasid may be used for DMA on device dev_id.
 *
 * Without CFG_FEATURE_DMA_COPY_SVA every request is accepted. With it:
 *   - an out-of-range dev_id is rejected;
 *   - the DEVDRV_MDEV_FULL_SPEC_VF_VM_BOOT and DEVDRV_MDEV_VF_VM_BOOT boot
 *     modes are accepted without consulting the tree;
 *   - otherwise pasid must be present in the device's PASID tree.
 *
 * Returns true when the PASID is accepted, false otherwise.
 */
bool devdrv_dma_pasid_valid_check(u32 dev_id, u64 pasid, int env_boot_mode)
{
#ifdef CFG_FEATURE_DMA_COPY_SVA
    struct devdrv_dma_pasid_rbtree_ctrl *ctrl = NULL;

    if (dev_id >= MAX_DEV_CNT) {
        devdrv_err_spinlock("Invalid dev_id.(dev_id=%u)\n", dev_id);
        return false;
    }

    if (env_boot_mode == DEVDRV_MDEV_FULL_SPEC_VF_VM_BOOT) {
        return true;
    }
    if (env_boot_mode == DEVDRV_MDEV_VF_VM_BOOT) {
        return true;
    }

    ctrl = &g_pasid_rbtree[dev_id];
    if (devdrv_dma_pasid_node_search(&ctrl->rb_lock, &ctrl->rbtree, pasid) != NULL) {
        return true;
    }

    devdrv_err_spinlock("Pasid is not in rbtree.(dev_id=%u; pasid=%llu)\n", dev_id, pasid);
    return false;
#else
    return true;
#endif
}
/*
 * Reset the PASID tree of the device behind pci_ctrl to an empty root and
 * initialise its spinlock. Does nothing without CFG_FEATURE_DMA_COPY_SVA.
 */
void devdrv_pasid_rbtree_init(struct devdrv_pci_ctrl *pci_ctrl)
{
#ifdef CFG_FEATURE_DMA_COPY_SVA
    struct devdrv_dma_pasid_rbtree_ctrl *ctrl = &g_pasid_rbtree[pci_ctrl->dev_id];

    ctrl->rbtree = KA_RB_ROOT;
    ka_task_spin_lock_init(&ctrl->rb_lock);
#endif
}
/*
 * Release every node of the PASID tree belonging to the device behind
 * pci_ctrl. Does nothing without CFG_FEATURE_DMA_COPY_SVA.
 */
void devdrv_pasid_rbtree_uninit(struct devdrv_pci_ctrl *pci_ctrl)
{
#ifdef CFG_FEATURE_DMA_COPY_SVA
    devdrv_pasid_rbtree_clear(pci_ctrl->dev_id);
#endif
}
/*
 * Allocate the PASID non-trans message channel for the device behind pci_ctrl
 * and remember it in the device's control entry.
 *
 * Nothing is allocated for a DEVDRV_SRIOV_TYPE_VF device or when
 * CFG_FEATURE_DMA_COPY_SVA is not defined.
 *
 * Returns 0 on success (or when there is nothing to do), -EINVAL when the
 * channel cannot be allocated.
 */
int devdrv_pasid_non_trans_init(struct devdrv_pci_ctrl *pci_ctrl)
{
#ifdef CFG_FEATURE_DMA_COPY_SVA
    void *chan = NULL;

    if (pci_ctrl->virtfn_flag == DEVDRV_SRIOV_TYPE_VF) {
        return 0;
    }

    chan = devdrv_pcimsg_alloc_non_trans_queue_inner(pci_ctrl->dev_id, &devdrv_pasid_non_trans_msg_chan_info);
    if (chan != NULL) {
        g_pasid_rbtree[pci_ctrl->dev_id].msg_chan = (struct devdrv_msg_chan *)chan;
        return 0;
    }

    devdrv_err("Alloc common msg_queue failed, msg_chan is null. (dev_id=%u)\n", pci_ctrl->dev_id);
    return -EINVAL;
#else
    return 0;
#endif
}
/*
 * Free the PASID non-trans message channel of the device behind pci_ctrl and
 * clear the stored handle.
 *
 * Skipped for a DEVDRV_SRIOV_TYPE_VF device and when CFG_FEATURE_DMA_COPY_SVA
 * is not defined. A non-zero result from the free call is only logged at info
 * level; the handle is cleared either way.
 */
void devdrv_pasid_non_trans_uninit(struct devdrv_pci_ctrl *pci_ctrl)
{
#ifdef CFG_FEATURE_DMA_COPY_SVA
    struct devdrv_dma_pasid_rbtree_ctrl *ctrl = NULL;
    int ret;

    if (pci_ctrl->virtfn_flag == DEVDRV_SRIOV_TYPE_VF) {
        return;
    }

    ctrl = &g_pasid_rbtree[pci_ctrl->dev_id];
    ret = devdrv_pcimsg_free_non_trans_queue_inner((void *)ctrl->msg_chan);
    if (ret != 0) {
        devdrv_info("No need to free common msg_queue. (dev_id=%u; ret=%d)\n", pci_ctrl->dev_id, ret);
    }
    ctrl->msg_chan = NULL;
#endif
}