* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 and
* only version 2 as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*/
#include "ka_base_pub.h"
#include "ka_common_pub.h"
#include "ka_compiler_pub.h"
#include "ka_ioctl_pub.h"
#include "ka_memory_pub.h"
#include "ka_kernel_def_pub.h"
#include "ascend_hal_define.h"
#include "rmo_auto_init.h"
#include "comm_kernel_interface.h"
#include "pbl_kernel_interface.h"
#include "pbl/pbl_uda.h"
#include "pbl/pbl_task_ctx.h"
#include "dpa/dpa_rmo_kernel.h"
#include "rmo_kern_log.h"
#include "rmo_ioctl.h"
#include "rmo_fops.h"
#include "rmo_sched.h"
#include "rmo_mem_sharing_ctx.h"
#include "rmo_mem_sharing.h"
/* Largest request size, in bytes, accepted by the mem-sharing ioctl; larger sizes are rejected with -EINVAL. */
#define RMO_MEM_SHARING_MAX_SIZE 4096
/*
 * Task-context domain that holds the sharing nodes, looked up by tgid.
 * Created in rmo_mem_sharing_init() and destroyed in rmo_mem_sharing_uninit().
 */
static struct task_ctx_domain *res_mem_sharing_ops_domain = NULL;
/*
 * Populate a memory attribute descriptor with the given address and size.
 *
 * Only addr, size, cp_only_flag (false) and raw_pa_flag (true) are written;
 * every other field of *mem_attr keeps whatever value the caller put there.
 */
static inline void rmo_pack_mem_attr(u64 addr, u64 size, struct ka_mem_attr *mem_attr)
{
    /* Fixed flags used by every caller in this file. */
    mem_attr->raw_pa_flag = true;
    mem_attr->cp_only_flag = false;

    /* Caller-supplied range. */
    mem_attr->size = size;
    mem_attr->addr = addr;
}
/*
 * Look up the physical address behind (addr, size) for the current process.
 *
 * @devid:   device id forwarded to hal_kernel_get_mem_pa_list().
 * @addr:    start address of the range.
 * @size:    length of the range in bytes.
 * @pa_list: receives the single physical address on success; left untouched on failure.
 *
 * Exactly one entry is requested from the HAL (the count passed in is 1).
 *
 * Return: 0 on success, otherwise the error returned by the HAL call.
 */
static int rmo_mem_get_pa_list(u32 devid, u64 addr, u64 size, u64 *pa_list)
{
    struct ka_mem_attr attr = {0};
    struct ka_pa_wraper wrap = {0};
    u64 entry_num = 1;
    int ret;

    rmo_pack_mem_attr(addr, size, &attr);

    ret = hal_kernel_get_mem_pa_list(devid, ka_task_get_current_tgid(), &attr, &entry_num, &wrap);
    if (ret != 0) {
        return ret;
    }

    *pa_list = wrap.pa;
    return 0;
}
/*
 * Hand back the physical address previously obtained for (addr, size).
 *
 * @devid:   device id forwarded to hal_kernel_put_mem_pa_list().
 * @addr:    start address of the range.
 * @size:    length of the range in bytes.
 * @pa_list: points at the physical address to release; only the first entry is read.
 *
 * Return: the result of hal_kernel_put_mem_pa_list().
 */
static int rmo_mem_put_pa_list(u32 devid, u64 addr, u64 size, u64 *pa_list)
{
    struct ka_mem_attr attr = {0};
    struct ka_pa_wraper wrap = {
        .pa = *pa_list,
        .size = size,
    };
    u64 entry_num = 1;

    rmo_pack_mem_attr(addr, size, &attr);

    return hal_kernel_put_mem_pa_list(devid, ka_task_get_current_tgid(), &attr, entry_num, &wrap);
}
/*
 * Per-accessor "get physical address" callbacks, indexed by accessor id.
 * Only TS_ACCESSOR has an entry; every other slot is NULL.
 */
static int (*const rmo_mem_get_func[ACCESSOR_MAX])(u32 devid, u64 addr, u64 size, u64 *pa_list) = {
[TS_ACCESSOR] = rmo_mem_get_pa_list,
};
/*
 * Per-accessor "put physical address" callbacks, indexed by accessor id.
 * Only TS_ACCESSOR has an entry; every other slot is NULL.
 */
static int (*const rmo_mem_put_func[ACCESSOR_MAX])(u32 devid, u64 addr, u64 size, u64 *pa_list) = {
[TS_ACCESSOR] = rmo_mem_put_pa_list,
};
/*
 * Per-accessor sharing handlers installed by rmo_mem_sharing_register() and
 * cleared by rmo_mem_sharing_unregister(); a NULL slot means nothing is registered.
 */
static mem_sharing_func g_mem_sharing_func[ACCESSOR_MAX];
/*
 * Install the sharing handler for one accessor.
 *
 * @handle:   callback stored in the accessor's slot (stored as given, no NULL check).
 * @accessor: slot index; values outside [0, ACCESSOR_MAX) are silently ignored.
 */
void rmo_mem_sharing_register(mem_sharing_func handle, accessMember_t accessor)
{
    if ((accessor < 0) || (accessor >= ACCESSOR_MAX)) {
        return;
    }

    g_mem_sharing_func[accessor] = handle;
    rmo_debug("Register mem dispatch func success. (accessor=%d)\n", accessor);
}
KA_EXPORT_SYMBOL_GPL(rmo_mem_sharing_register);
/*
 * Remove the sharing handler of one accessor by resetting its slot to NULL.
 *
 * @accessor: slot index; values outside [0, ACCESSOR_MAX) are silently ignored.
 */
void rmo_mem_sharing_unregister(accessMember_t accessor)
{
    if ((accessor < 0) || (accessor >= ACCESSOR_MAX)) {
        return;
    }

    g_mem_sharing_func[accessor] = NULL;
    rmo_debug("Unregister mem dispatch func success. (accessor=%d)\n", accessor);
}
KA_EXPORT_SYMBOL_GPL(rmo_mem_sharing_unregister);
/*
 * Invoke the sharing handler registered for mem_sharing->accessor.
 *
 * @devid:       device id passed through to the handler.
 * @raw_addr:    address descriptor passed through to the handler; callers in
 *               this file pass NULL on the disable / rollback paths.
 * @mem_sharing: request; accessor selects the handler and size is forwarded.
 *               The caller must already have range-checked accessor.
 *
 * Return: -ENODEV when no handler is registered, otherwise the handler's result.
 */
static int rmo_mem_sharing_func_proc(u32 devid, struct rmo_mem_raw_addr *raw_addr,
    struct rmo_cmd_mem_sharing *mem_sharing)
{
    /*
     * Take a single snapshot of the slot: rmo_mem_sharing_unregister() may
     * clear it at any time, and testing the global and then calling through
     * it again could end up calling a NULL pointer. This only removes the
     * double read; it is not a substitute for real synchronization with
     * handler teardown.
     */
    mem_sharing_func handler = g_mem_sharing_func[mem_sharing->accessor];
    int ret;

    if (handler == NULL) {
        rmo_err("Not register. (devid=%u; accessor=%u)\n", devid, mem_sharing->accessor);
        return -ENODEV;
    }

    ret = handler(devid, raw_addr, mem_sharing->size);
    if (ret != 0) {
        rmo_err("Failed to share. (ret=%d; devid=%u; accessor=%u; len=%llu; enable_flag=%u)\n",
            ret, devid, mem_sharing->accessor, mem_sharing->size, mem_sharing->enable_flag);
    }
    return ret;
}
static int rmo_mem_sharing_enable(struct rmo_cmd_mem_sharing *mem_sharing)
{
struct rmo_mem_sharing_info info = {0};
struct rmo_mem_map_addr convert_addr = {0};
u32 devid = mem_sharing->devid;
u32 id = uda_get_host_id();
int tgid = ka_task_get_current_tgid();
u64 paddr;
int ret;
ret = rmo_mem_get_func[mem_sharing->accessor](id, (u64)(uintptr_t)mem_sharing->ptr, mem_sharing->size, &paddr);
if (ret != 0) {
rmo_err("Failed to get addr. (ret=%d; devid=%u; accessor=%u; tgid=%d)\n",
ret, id, mem_sharing->accessor, tgid);
return ret;
}
ret = rmo_mem_addr_map(devid, paddr, mem_sharing->size, &convert_addr);
if (ret != 0) {
rmo_err("Failed to update addr. (ret=%d; devid=%u; accessor=%u; tgid=%d)\n",
ret, devid, mem_sharing->accessor, tgid);
goto err_to_put;
}
ret = rmo_mem_sharing_func_proc(devid, &convert_addr.raw_addr, mem_sharing);
if (ret != 0) {
rmo_err("Failed to share. (ret=%d; devid=%u; accessor=%u; len=%llu; tgid=%d)\n",
ret, devid, mem_sharing->accessor, mem_sharing->size, tgid);
goto err_to_unmap;
}
info.sharing_pa = paddr;
info.convert_addr = convert_addr;
info.mem_shr = *mem_sharing;
ret = rmo_mem_sharing_add_node(res_mem_sharing_ops_domain, tgid, &info);
if (ret != 0) {
rmo_err("Failed to add node. (ret=%d; tgid=%d; devid=%u; accessor=%u; tgid=%d)\n",
ret, tgid, devid, mem_sharing->accessor, tgid);
goto err_to_func;
}
rmo_debug("Enable success. (devid=%u; accessor=%u; len=%llu; enable_flag=%u; tgid=%d)\n",
devid, mem_sharing->accessor, mem_sharing->size, mem_sharing->enable_flag, tgid);
return 0;
err_to_func:
(void)rmo_mem_sharing_func_proc(devid, NULL, mem_sharing);
err_to_unmap:
(void)rmo_mem_addr_unmap(devid, &convert_addr, mem_sharing->size);
err_to_put:
(void)rmo_mem_put_func[mem_sharing->accessor](id, (u64)(uintptr_t)mem_sharing->ptr, mem_sharing->size, &paddr);
return ret;
}
static int rmo_mem_sharing_disable(struct rmo_cmd_mem_sharing *mem_sharing)
{
struct rmo_mem_sharing_info info;
int tgid = ka_task_get_current_tgid();
u32 devid = mem_sharing->devid;
u32 id = uda_get_host_id();
int ret;
info.mem_shr = *mem_sharing;
ret = rmo_mem_sharing_query_node(res_mem_sharing_ops_domain, tgid, &info);
if (ret != 0) {
rmo_err("Failed to find node. (ret=%d; devid=%u; accessor=%u; len=%llu; tgid=%d)\n",
ret, devid, mem_sharing->accessor, mem_sharing->size, tgid);
return ret;
}
ret = rmo_mem_sharing_del_node(res_mem_sharing_ops_domain, tgid, &info);
if (ret != 0) {
rmo_err("Failed to del node. (ret=%d; devid=%u; accessor=%u; len=%llu; tgid=%d)\n",
ret, devid, mem_sharing->accessor, mem_sharing->size, tgid);
return ret;
}
ret = rmo_mem_sharing_func_proc(devid, NULL, mem_sharing);
if (ret != 0) {
(void)rmo_mem_sharing_add_node(res_mem_sharing_ops_domain, tgid, &info);
rmo_err("Failed to share. (ret=%d; devid=%u; accessor=%u; len=%llu; tgid=%d)\n",
ret, devid, mem_sharing->accessor, mem_sharing->size, tgid);
return ret;
}
(void)rmo_mem_addr_unmap(devid, &info.convert_addr, mem_sharing->size);
ret = rmo_mem_put_func[mem_sharing->accessor](id, (u64)(uintptr_t)mem_sharing->ptr, mem_sharing->size,
&info.sharing_pa);
if (ret != 0) {
rmo_warn("Put addr warnning. (ret=%d; devid=%u; accessor=%u)\n",
ret, id, mem_sharing->accessor);
} else {
rmo_debug("Disable success. (devid=%u; accessor=%u; len=%llu; enable_flag=%u; tgid=%d)\n",
devid, mem_sharing->accessor, mem_sharing->size, mem_sharing->enable_flag, tgid);
}
return 0;
}
/*
 * Dispatch a validated request to the enable or disable path.
 *
 * enable_flag == 0 selects rmo_mem_sharing_enable(); any other value selects
 * rmo_mem_sharing_disable().
 * NOTE(review): "0 means enable" reads as inverted relative to the field
 * name - confirm the polarity against the user-space ABI definition.
 *
 * Return: the result of the selected path.
 */
static int rmo_mem_sharing(struct rmo_cmd_mem_sharing *mem_sharing)
{
    if (mem_sharing->enable_flag != 0) {
        return rmo_mem_sharing_disable(mem_sharing);
    }

    return rmo_mem_sharing_enable(mem_sharing);
}
/*
 * ioctl entry point for RMO_MEM_SHARING.
 *
 * Copies the request from user space, converts the device id with
 * uda_devid_to_udevid(), validates every field, and then dispatches to
 * rmo_mem_sharing().
 *
 * @cmd: ioctl command (not used here).
 * @arg: user-space pointer to a struct rmo_cmd_mem_sharing.
 *
 * Return: 0 on success;
 *         -EFAULT     if the request cannot be copied in;
 *         -ENODEV     if the device id cannot be converted;
 *         -EOPNOTSUPP if the device is not a physical device or the accessor
 *                     is out of range;
 *         -EINVAL     if ptr is NULL or not page aligned, size is 0 or above
 *                     RMO_MEM_SHARING_MAX_SIZE, side is not MEM_HOST_SIDE, or
 *                     enable_flag is neither 0 nor 1;
 *         otherwise the result of rmo_mem_sharing().
 */
static int rmo_ioctl_mem_sharing(u32 cmd, unsigned long arg)
{
    struct rmo_cmd_mem_sharing *usr_arg = (struct rmo_cmd_mem_sharing __ka_user *)(uintptr_t)arg;
    struct rmo_cmd_mem_sharing mem_sharing;
    u32 usr_devid;
    int ret;

    ret = (int)ka_base_copy_from_user(&mem_sharing, usr_arg, sizeof(mem_sharing));
    if (ret != 0) {
        rmo_err("Copy from user failed. (ret=%d)\n", ret);
        return -EFAULT;
    }

    /*
     * The conversion writes its result into the same field it reads from, so
     * remember the caller's id first: the failure log must show what the user
     * passed in, not whatever the helper may have left in the field.
     */
    usr_devid = mem_sharing.devid;
    ret = uda_devid_to_udevid(usr_devid, &mem_sharing.devid);
    if (ret != 0) {
        rmo_err("Invalid devid. (devid=%u; ret=%d)\n", usr_devid, ret);
        return -ENODEV;
    }

    if (!uda_is_phy_dev(mem_sharing.devid)) {
        return -EOPNOTSUPP;
    }

    if ((mem_sharing.ptr == NULL) || (mem_sharing.size == 0) || (mem_sharing.size > RMO_MEM_SHARING_MAX_SIZE) ||
        (((u64)(uintptr_t)(mem_sharing.ptr) & (KA_MM_PAGE_SIZE - 1)) != 0)) {
        rmo_err("Invalid para. (ptr=%p; size=%llu)\n", mem_sharing.ptr, mem_sharing.size);
        return -EINVAL;
    }

    /* accessor is used as an array index further down; reject anything out of range. */
    if ((mem_sharing.accessor < 0) || (mem_sharing.accessor >= ACCESSOR_MAX)) {
        return -EOPNOTSUPP;
    }

    if (mem_sharing.side != MEM_HOST_SIDE) {
        rmo_err("Invalid memory side. (side=%d)\n", mem_sharing.side);
        return -EINVAL;
    }

    if ((mem_sharing.enable_flag != 0) && (mem_sharing.enable_flag != 1)) {
        rmo_err("Invalid enable_flag. (enable_flag=%d)\n", mem_sharing.enable_flag);
        return -EINVAL;
    }

    return rmo_mem_sharing(&mem_sharing);
}
/*
 * Task-exit hook: destroy the mem-sharing context kept for an exiting process.
 *
 * @udevid:     not used by this hook.
 * @tgid:       process whose context is destroyed.
 * @start_time: not used by this hook.
 */
void rmo_mem_sharing_domain_task_exit(u32 udevid, int tgid, struct task_start_time *start_time)
{
    (void)udevid;
    (void)start_time;

    rmo_mem_sharing_ctx_destroy(res_mem_sharing_ops_domain, tgid);
}
DECLAER_FEATURE_AUTO_UNINIT_TASK(rmo_mem_sharing_domain_task_exit, FEATURE_LOADER_STAGE_1);
/*
 * Feature init: create the task-context domain and register the
 * RMO_MEM_SHARING ioctl handler.
 *
 * Return: 0 on success, -ENOMEM if the domain cannot be created (the ioctl
 *         handler is not registered in that case).
 */
int rmo_mem_sharing_init(void)
{
    res_mem_sharing_ops_domain = task_ctx_domain_create("mem_sharing_domain", 0);
    if (res_mem_sharing_ops_domain != NULL) {
        rmo_register_ioctl_cmd_func(_KA_IOC_NR(RMO_MEM_SHARING), rmo_ioctl_mem_sharing);
        return 0;
    }

    return -ENOMEM;
}
DECLAER_FEATURE_AUTO_INIT(rmo_mem_sharing_init, FEATURE_LOADER_STAGE_6);
/*
 * Feature uninit: destroy the task-context domain created by
 * rmo_mem_sharing_init() and reset the global pointer to NULL.
 */
void rmo_mem_sharing_uninit(void)
{
    struct task_ctx_domain *domain = res_mem_sharing_ops_domain;

    task_ctx_domain_destroy(domain);
    res_mem_sharing_ops_domain = NULL;
}
DECLAER_FEATURE_AUTO_UNINIT(rmo_mem_sharing_uninit, FEATURE_LOADER_STAGE_6);