* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <pthread.h>
#include <unistd.h>
#include "ascend_hal_error.h"
#include "queue_kernel_api.h"
#include "que_urma.h"
#include "que_comm_agent.h"
#include "que_compiler.h"
#include "que_mem_merge.h"
/*
 * Look for a tree node whose range overlaps the inclusive range `range`.
 *
 * root:      tree to walk.
 * range:     inclusive [start, end] range being probed.
 * get_range: callback that fills in the inclusive range covered by a node.
 *
 * Returns the que_mem_node embedding the first overlapping tree node met on
 * the way down, or NULL when nothing overlaps. When COMPILE_UT is defined the
 * walk is compiled out and the function always returns NULL.
 */
static struct que_mem_node *rbtree_can_insert_range_ex(struct rbtree_root *root, struct rb_range_handle *range,
    rb_range_handle_func get_range)
{
    struct rbtree_node **slot = &(root->rbtree_node);

#ifndef COMPILE_UT
    for (; *slot != NULL;) {
        struct rbtree_node *cur = *slot;
        struct rb_range_handle cur_range;

        get_range(cur, &cur_range);
        if (range->end < cur_range.start) {
            /* Probe lies entirely below this node: descend left. */
            slot = &(cur->rbtree_left);
            continue;
        }
        if (range->start > cur_range.end) {
            /* Probe lies entirely above this node: descend right. */
            slot = &(cur->rbtree_right);
            continue;
        }
        /* Neither strictly below nor strictly above: the ranges overlap. */
        return rb_entry(cur, struct que_mem_node, node);
    }
#endif
    return NULL;
}
/*
 * Range callback for the memory tree: reports the inclusive address span
 * [va, va + size - 1] of the que_mem_node that embeds `node`.
 */
static void que_rb_range_handle_of_mem(struct rbtree_node *node, struct rb_range_handle *range_handle)
{
    struct que_mem_node *owner = rb_entry(node, struct que_mem_node, node);
    unsigned long long first = owner->va;

    range_handle->start = first;
    range_handle->end = first + owner->size - 1;
}
/*
 * Insert `new_node` into the context's range tree.
 *
 * Returns DRV_ERROR_NONE when rbtree_insert_by_range() reports 0, and
 * DRV_ERROR_BUSY for any other result.
 */
static int que_mem_new_node_insert(struct que_mem_merge_ctx *mem_ctx, struct que_mem_node *new_node)
{
    if (rbtree_insert_by_range(&mem_ctx->rb_root, &new_node->node, que_rb_range_handle_of_mem) != 0) {
        return DRV_ERROR_BUSY;
    }
    return DRV_ERROR_NONE;
}
static struct que_mem_node *que_mem_find_node(struct que_mem_merge_ctx *mem_ctx, unsigned long long va, unsigned long long size)
{
struct rbtree_node *rb_node = NULL;
struct rb_range_handle range = {.start = va, .end = va + size - 1};
rb_node = rbtree_search_by_range(&mem_ctx->rb_root, &range, que_rb_range_handle_of_mem);
if (que_unlikely(rb_node == NULL)) {
return NULL;
}
return rb_entry(rb_node, struct que_mem_node, node);
}
/*
 * Unlink from the tree the node found by a one-byte lookup at `va`.
 *
 * The node is only removed from the tree, not freed; ownership passes to the
 * caller. Returns the unlinked node, or NULL when the lookup finds nothing.
 */
static struct que_mem_node *que_mem_node_erase(struct que_mem_merge_ctx *mem_ctx, unsigned long long va)
{
    struct que_mem_node *victim = que_mem_find_node(mem_ctx, va, 1);

    if (victim != NULL) {
        _rbtree_erase(&mem_ctx->rb_root, &victim->node);
    }
    return victim;
}
static struct que_mem_node *que_exist_mem_conflict(struct que_mem_merge_ctx *mem_ctx, unsigned long long va, unsigned long long size)
{
struct rb_range_handle range_handle = {.start = va, .end = va + size - 1};
return rbtree_can_insert_range_ex(&mem_ctx->rb_root, &range_handle, que_rb_range_handle_of_mem);
}
/*
 * Grow the range (*va_new, *size_new) so that it also covers `mem_node`.
 *
 * On return *va_new is the lower of the two start addresses and *size_new
 * spans from there to the higher of the two (exclusive) end addresses.
 */
static void que_mem_range_merge(struct que_mem_node *mem_node, unsigned long long *va_new, unsigned long long *size_new)
{
    unsigned long long lowest = *va_new;
    unsigned long long highest = *va_new + *size_new;
    const unsigned long long node_end = mem_node->va + mem_node->size;

    if (mem_node->va < lowest) {
        lowest = mem_node->va;
    }
    if (node_end > highest) {
        highest = node_end;
    }

    *va_new = lowest;
    *size_new = highest - lowest;
}
/*
 * Allocate a tree node describing the range [va, va + size).
 *
 * The node starts out detached from any tree (RB_CLEAR_NODE) and with no
 * segment attached (tseg == NULL).
 *
 * Returns the new node, or NULL when malloc() fails (an error is logged).
 * The caller owns the node and releases it with que_mem_node_free().
 */
static struct que_mem_node *que_mem_node_alloc(unsigned long long va, unsigned long long size)
{
    struct que_mem_node *mem_node = NULL;

    mem_node = (struct que_mem_node *)malloc(sizeof(*mem_node));
    if (que_unlikely(mem_node == NULL)) {
        /* sizeof yields size_t; cast so the argument really matches %llu. */
        QUEUE_LOG_ERR("malloc new node failed. (size=%llu)\n", (unsigned long long)sizeof(*mem_node));
        return NULL;
    }

    RB_CLEAR_NODE(&mem_node->node);
    mem_node->va = va;
    mem_node->size = size;
    mem_node->tseg = NULL;
    return mem_node;
}
/*
 * Release the memory of a node. Only the node itself is freed; the segment
 * referenced by its tseg field is not touched here.
 */
static void que_mem_node_free(struct que_mem_node *mem_node)
{
    free(mem_node);
}
/*
 * Drain the context's range tree under the write lock.
 *
 * Every node is removed and freed. A node's segment is destroyed first when
 * one is attached and the node's address is not reported as shared memory by
 * que_is_share_mem().
 */
void que_mem_erase_all_node(struct que_mem_merge_ctx *mem_ctx)
{
    struct rbtree_node *victim = NULL;
    struct que_mem_node *entry = NULL;

    (void)pthread_rwlock_wrlock(&mem_ctx->rwlock);

    if (RB_EMPTY_ROOT(&mem_ctx->rb_root) != true) {
        for (victim = rbtree_erase_one_node(&mem_ctx->rb_root); victim != NULL;
             victim = rbtree_erase_one_node(&mem_ctx->rb_root)) {
            entry = rb_entry(victim, struct que_mem_node, node);
            if (que_unlikely(entry->tseg != NULL) && (que_is_share_mem(entry->va) == false)) {
                que_seg_destroy(entry->tseg);
            }
            que_mem_node_free(entry);
        }
    }

    (void)pthread_rwlock_unlock(&mem_ctx->rwlock);
}
/*
 * Compute the probe range used to find stored nodes that overlap or directly
 * abut [aligned_va, aligned_va + aligned_size): it starts up to one page below
 * aligned_va and its inclusive end is aligned_va + aligned_size. The lower
 * padding is clamped so the start can never wrap below address 0.
 */
static void que_mem_merge_probe_range(unsigned long long aligned_va, unsigned long long aligned_size,
    size_t page_size, unsigned long long *judge_va, unsigned long long *judge_size)
{
    unsigned long long lower_pad = (aligned_va >= page_size) ? page_size : aligned_va;

    *judge_va = aligned_va - lower_pad;
    *judge_size = aligned_size + lower_pad + 1;
}

/*
 * Record the page-aligned range covering [va, va + size) in the context's
 * range tree, merging it with every stored node that overlaps or abuts it.
 *
 * mem_ctx: context holding the tree and its rwlock (taken for writing here).
 * devid:   only used in log messages.
 * va/size: range to record; expanded outwards to page boundaries.
 *
 * Returns DRV_ERROR_NONE on success, DRV_ERROR_OUT_OF_MEMORY when the new node
 * cannot be allocated (the tree is left untouched in that case),
 * DRV_ERROR_INNER_ERR when a conflicting node cannot be erased, or the error
 * from que_mem_new_node_insert().
 */
int que_mem_merge(struct que_mem_merge_ctx *mem_ctx, unsigned int devid, unsigned long long va, unsigned long long size)
{
    struct que_mem_node *erased_node = NULL;
    struct que_mem_node *conflict_node = NULL;
    struct que_mem_node *new_mem_node = NULL;
    size_t page_size = (size_t)getpagesize();
    unsigned long long aligned_va = que_align_down(va, page_size);
    unsigned long long aligned_size = que_align_up(va + size, page_size) - aligned_va;
    unsigned long long judge_va = 0;
    unsigned long long judge_size = 0;
    int ret;

    /*
     * Allocate before touching the tree: conflicting nodes are erased and
     * freed while merging, so failing to allocate afterwards would drop the
     * ranges they described.
     */
    new_mem_node = que_mem_node_alloc(aligned_va, aligned_size);
    if (que_unlikely(new_mem_node == NULL)) {
        QUEUE_LOG_ERR("malloc new node failed. (size=%llu)\n", (unsigned long long)sizeof(struct que_mem_node));
        return DRV_ERROR_OUT_OF_MEMORY;
    }

    (void)pthread_rwlock_wrlock(&mem_ctx->rwlock);

    que_mem_merge_probe_range(aligned_va, aligned_size, page_size, &judge_va, &judge_size);
    conflict_node = que_exist_mem_conflict(mem_ctx, judge_va, judge_size);
    while (conflict_node != NULL) {
        erased_node = que_mem_node_erase(mem_ctx, conflict_node->va);
        if (que_unlikely(erased_node == NULL)) {
            (void)pthread_rwlock_unlock(&mem_ctx->rwlock);
            que_mem_node_free(new_mem_node);
            QUEUE_LOG_ERR("merge mem node fail. (devid=%u; va=0x%llx; size=%llu)\n", devid, va, size);
            return DRV_ERROR_INNER_ERR;
        }

        /* Absorb the erased node; the grown range may reach further nodes. */
        que_mem_range_merge(erased_node, &aligned_va, &aligned_size);
        /*
         * NOTE(review): erased_node->tseg is not released here, unlike in
         * que_mem_seg_unregister(); confirm an absorbed node can never carry
         * a registered segment.
         */
        que_mem_node_free(erased_node);

        que_mem_merge_probe_range(aligned_va, aligned_size, page_size, &judge_va, &judge_size);
        conflict_node = que_exist_mem_conflict(mem_ctx, judge_va, judge_size);
    }

    /* The node was allocated with the pre-merge range; store the final one. */
    new_mem_node->va = aligned_va;
    new_mem_node->size = aligned_size;

    ret = que_mem_new_node_insert(mem_ctx, new_mem_node);
    if (que_unlikely(ret != DRV_ERROR_NONE)) {
        (void)pthread_rwlock_unlock(&mem_ctx->rwlock);
        que_mem_node_free(new_mem_node);
        QUEUE_LOG_ERR("insert mem node fail. (devid=%u; va=0x%llx; size=%llu)\n", devid, va, size);
        return ret;
    }

    (void)pthread_rwlock_unlock(&mem_ctx->rwlock);
    return DRV_ERROR_NONE;
}
/*
 * Return the segment attached to the tracked node containing `va`, creating
 * and attaching one on first use. Runs under the context's write lock.
 *
 * When the node has no segment yet:
 *   - if que_is_share_mem() reports the node's address as shared memory, the
 *     segment comes from que_get_urma_ctx_tseg(devid, d2d_flag);
 *   - otherwise a segment is created over the whole node range with
 *     que_nonpin_seg_create() when pin_flg is 0, else que_pin_seg_create().
 *
 * Returns the segment, or NULL when no node contains `va` or when obtaining
 * a segment fails (the latter is logged).
 */
urma_target_seg_t *que_mem_seg_register(struct que_mem_merge_ctx *mem_ctx, unsigned int d2d_flag, struct que_urma_token *token,
    unsigned int pin_flg, unsigned int devid, unsigned long long va, unsigned int access)
{
    struct que_mem_node *node = NULL;
    urma_target_seg_t *seg = NULL;
    int create_failed = 0;

    (void)pthread_rwlock_wrlock(&mem_ctx->rwlock);

    node = que_mem_find_node(mem_ctx, va, 1);
    if (que_unlikely(node == NULL)) {
        goto unlock_out;
    }

    seg = node->tseg;
    if (que_unlikely(seg != NULL)) {
        /* Already registered: hand back the cached segment. */
        goto unlock_out;
    }

    if (que_is_share_mem(node->va) == true) {
        seg = que_get_urma_ctx_tseg(devid, d2d_flag);
    } else if (pin_flg == 0) {
        seg = que_nonpin_seg_create(devid, node->va, node->size, access, token, d2d_flag);
    } else {
        seg = que_pin_seg_create(devid, node->va, node->size, access, token, d2d_flag);
    }

    if (que_unlikely(seg == NULL)) {
        create_failed = 1;
        goto unlock_out;
    }
    node->tseg = seg;

unlock_out:
    (void)pthread_rwlock_unlock(&mem_ctx->rwlock);
    if (create_failed != 0) {
        /* Logged after the lock is dropped. */
        QUEUE_LOG_ERR("que mem seg fail. (devid=%u; va=0x%llx)\n", devid, va);
    }
    return seg;
}
/*
 * Remove the tracked node containing `va` and release it, under the
 * context's write lock.
 *
 * The node's segment is destroyed first when one is attached and the node's
 * address is not reported as shared memory by que_is_share_mem(). Nothing
 * happens when no node contains `va`. `devid` is not used by this function.
 */
void que_mem_seg_unregister(struct que_mem_merge_ctx *mem_ctx, unsigned int devid, unsigned long long va)
{
    struct que_mem_node *victim = NULL;

    (void)pthread_rwlock_wrlock(&mem_ctx->rwlock);

    victim = que_mem_node_erase(mem_ctx, va);
    if (victim != NULL) {
        if (que_unlikely(victim->tseg != NULL) && (que_is_share_mem(victim->va) == false)) {
            que_seg_destroy(victim->tseg);
        }
        que_mem_node_free(victim);
    }

    (void)pthread_rwlock_unlock(&mem_ctx->rwlock);
}
/*
 * Prepare a merge context for use: initialise its rwlock with default
 * attributes and reset its range tree. The result of pthread_rwlock_init()
 * is deliberately ignored.
 */
void que_mem_ctx_init(struct que_mem_merge_ctx *mem_ctx)
{
    (void)pthread_rwlock_init(&mem_ctx->rwlock, NULL);
    rbtree_init(&mem_ctx->rb_root);
}