* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 and
* only version 2 as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*/
#ifndef PBL_TASK_CTX_H
#define PBL_TASK_CTX_H
#include <linux/version.h>
#include <linux/kref.h>
#include <linux/list.h>
#include <linux/spinlock.h>
#include <linux/hashtable.h>
#include <linux/sched.h>
#include <linux/sched/task.h>
#include <linux/sched/signal.h>
#include <linux/slab.h>
#include <linux/vmalloc.h>
#include <linux/mutex.h>
#include <linux/pid.h>
#include <linux/seq_file.h>
#include "securec.h"
#include "pbl_kref_safe.h"
usage : Process Context Management with reference counting
1. call task_ctx_domain_create create a domain when module_init
2. call task_ctx_create create a task context in the domain when process open a file
3. call task_ctx_get and task_ctx_put in business
4. call task_ctx_destroy when process exit or close a file
5. call task_ctx_for_each_safe traverse a domain
6. task_ctx_domain_show
7. call task_ctx_domain_destroy destroy domain when module_exit
*/
#define TASK_CTX_HASH_BIT 6
#define TASK_CTX_HASH_MAX (1 << TASK_CTX_HASH_BIT)
#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)
#define TASK_START_TIME_EQUAL(time, start_time) ((time) == (start_time))
#else
#define TASK_START_TIME_EQUAL(time, start_time) (timespec_equal((&(time)), (&(start_time))))
#endif
#define TASK_CTX_DOMAIN_FLAG_ACTIVE (1ULL << 0ULL)
struct task_ctx_domain {
struct kref_safe ref;
DECLARE_HASHTABLE(ctx_htable, TASK_CTX_HASH_BIT);
rwlock_t ctx_lock[TASK_CTX_HASH_MAX];
struct mutex mutex;
const char *name;
atomic_t task_num;
u32 task_max_num;
u64 flags;
};
struct task_start_time {
#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)
u64 time;
#else
struct timespec time;
#endif
};
struct task_id_entity {
int tgid;
struct task_start_time start_time;
};
struct task_ctx {
struct kref_safe ref;
struct task_start_time start_time;
int tgid;
char name[TASK_COMM_LEN];
struct hlist_node link;
struct mutex mutex;
spinlock_t lock;
struct task_ctx_domain *domain;
void *priv;
void (*release)(struct task_ctx *ctx);
};
static inline bool task_ctx_domain_is_active(struct task_ctx_domain *domain)
{
return ((domain->flags & TASK_CTX_DOMAIN_FLAG_ACTIVE) != 0);
}
static inline void task_ctx_domain_set_active(struct task_ctx_domain *domain)
{
mutex_lock(&domain->mutex);
domain->flags |= TASK_CTX_DOMAIN_FLAG_ACTIVE;
mutex_unlock(&domain->mutex);
}
static inline int task_ctx_domain_set_inactive(struct task_ctx_domain *domain)
{
mutex_lock(&domain->mutex);
if (atomic_read(&domain->task_num) != 0) {
mutex_unlock(&domain->mutex);
return -EBUSY;
}
domain->flags &= (~TASK_CTX_DOMAIN_FLAG_ACTIVE);
mutex_unlock(&domain->mutex);
return 0;
}
static inline void task_ctx_domain_init(struct task_ctx_domain *domain, const char *name, u32 task_max_num)
{
u32 i;
domain->name = name;
domain->task_max_num = task_max_num;
domain->flags = TASK_CTX_DOMAIN_FLAG_ACTIVE;
hash_init(domain->ctx_htable);
mutex_init(&domain->mutex);
for (i = 0; i < TASK_CTX_HASH_MAX; i++) {
rwlock_init(&domain->ctx_lock[i]);
}
}
static inline struct task_ctx_domain *task_ctx_domain_create(const char *name, u32 task_max_num)
{
struct task_ctx_domain *domain = NULL;
domain = vzalloc(sizeof(*domain));
if (domain == NULL) {
return NULL;
}
task_ctx_domain_init(domain, name, task_max_num);
return domain;
}
static inline void task_ctx_domain_destroy(struct task_ctx_domain *domain)
{
mutex_destroy(&domain->mutex);
vfree(domain);
}
static inline u32 task_get_bkt_from_tgid(int tgid)
{
return hash_min(tgid, TASK_CTX_HASH_BIT);
}
static inline void task_ctx_add(struct task_ctx_domain *domain, struct task_ctx *ctx)
{
u32 bkt = task_get_bkt_from_tgid(ctx->tgid);
hlist_add_head(&ctx->link, &domain->ctx_htable[bkt]);
}
static inline void task_ctx_del(struct task_ctx *ctx)
{
hlist_del_init(&ctx->link);
}
static inline struct task_ctx *_task_ctx_find(struct task_ctx_domain *domain, int tgid)
{
u32 bkt = task_get_bkt_from_tgid(tgid);
struct task_ctx *ctx = NULL;
hlist_for_each_entry(ctx, &domain->ctx_htable[bkt], link) {
if (ctx->tgid == tgid) {
return ctx;
}
}
return NULL;
}
static inline int task_ctx_add_to_domain(struct task_ctx_domain *domain, struct task_ctx *ctx)
{
u32 bkt = task_get_bkt_from_tgid(ctx->tgid);
write_lock_bh(&domain->ctx_lock[bkt]);
if (_task_ctx_find(domain, ctx->tgid) != NULL) {
write_unlock_bh(&domain->ctx_lock[bkt]);
return -EEXIST;
}
task_ctx_add(domain, ctx);
write_unlock_bh(&domain->ctx_lock[bkt]);
atomic_inc(&domain->task_num);
return 0;
}
static inline void _task_ctx_release(struct task_ctx_domain *domain, struct task_ctx *ctx)
{
if (ctx->release != NULL) {
ctx->release(ctx);
}
mutex_destroy(&ctx->mutex);
vfree(ctx);
}
static inline void task_ctx_release(struct kref_safe *ref)
{
struct task_ctx *ctx = container_of(ref, struct task_ctx, ref);
_task_ctx_release(ctx->domain, ctx);
}
static inline struct task_ctx *_task_ctx_get_inner(struct task_ctx_domain *domain, int tgid, struct task_start_time *time)
{
struct task_ctx *ctx = _task_ctx_find(domain, tgid);
if (ctx != NULL) {
if ((time != NULL) && (TASK_START_TIME_EQUAL(ctx->start_time.time, time->time) == false)) {
return NULL;
}
if (kref_safe_get_unless_zero(&ctx->ref) == 0) {
return NULL;
}
}
return ctx;
}
static inline struct task_ctx *_task_ctx_get(struct task_ctx_domain *domain, int tgid, struct task_start_time *time)
{
u32 bkt = task_get_bkt_from_tgid(tgid);
struct task_ctx *ctx = NULL;
read_lock_bh(&domain->ctx_lock[bkt]);
ctx = _task_ctx_get_inner(domain, tgid, time);
read_unlock_bh(&domain->ctx_lock[bkt]);
return ctx;
}
static inline struct task_ctx *task_ctx_get(struct task_ctx_domain *domain, int tgid)
{
return _task_ctx_get(domain, tgid, NULL);
}
static inline struct task_ctx *task_ctx_time_check_get(struct task_ctx_domain *domain, int tgid,
struct task_start_time *time)
{
return _task_ctx_get(domain, tgid, time);
}
static inline void task_ctx_put(struct task_ctx *ctx)
{
kref_safe_put(&ctx->ref, task_ctx_release);
}
static inline void *task_ctx_priv(struct task_ctx *ctx)
{
return ctx->priv;
}
static inline struct task_struct *_task_get_by_tgid(int tgid)
{
struct pid *pid = NULL;
struct task_struct *task = NULL;
pid = get_pid(find_pid_ns(tgid, &init_pid_ns));
if (pid == NULL) {
return NULL;
}
task = get_pid_task(pid, PIDTYPE_PID);
put_pid(pid);
return task;
}
static inline struct task_struct *task_get_by_tgid(int tgid)
{
struct task_struct *task = NULL;
rcu_read_lock();
task = _task_get_by_tgid(tgid);
rcu_read_unlock();
return task;
}
static inline int task_get_start_time_by_tgid(int tgid, struct task_start_time *start_time)
{
struct task_struct *task = task_get_by_tgid(tgid);
if (task != NULL) {
start_time->time = task->group_leader->start_time;
put_task_struct(task);
return 0;
}
return -ESRCH;
}
static inline int _task_ctx_init(struct task_ctx_domain *domain, struct task_ctx *ctx,
int tgid, void *priv, void (*release)(struct task_ctx *ctx))
{
(void)strncpy_s(ctx->name, TASK_COMM_LEN, current->comm, TASK_COMM_LEN);
(void)task_get_start_time_by_tgid(tgid, &ctx->start_time);
ctx->tgid = tgid;
ctx->priv = priv;
mutex_init(&ctx->mutex);
spin_lock_init(&ctx->lock);
ctx->domain = domain;
ctx->release = release;
kref_safe_init(&ctx->ref);
return task_ctx_add_to_domain(domain, ctx);
}
static inline int task_ctx_init(struct task_ctx_domain *domain, struct task_ctx *ctx,
int tgid, void *priv, void (*release)(struct task_ctx *ctx))
{
if ((domain->task_max_num > 0) && (atomic_read(&domain->task_num) >= (int)domain->task_max_num)) {
return -EFAULT;
}
return _task_ctx_init(domain, ctx, tgid, priv, release);
}
static inline int task_ctx_create(struct task_ctx_domain *domain,
int tgid, void *priv, void (*release)(struct task_ctx *ctx))
{
struct task_ctx *ctx = NULL;
int ret;
if (task_ctx_domain_is_active(domain) == false) {
return -EPERM;
}
ctx = vzalloc(sizeof(*ctx));
if (ctx == NULL) {
return -ENOMEM;
}
ret = task_ctx_init(domain, ctx, tgid, priv, release);
if (ret != 0) {
vfree(ctx);
}
return ret;
}
static inline int task_ctx_create_ex(struct task_ctx_domain *domain,
int tgid, void *priv, void (*release)(struct task_ctx *ctx), struct task_ctx **ctx)
{
int ret = -ENOMEM;
if (task_ctx_domain_is_active(domain) == false) {
return -EPERM;
}
*ctx = vzalloc(sizeof(**ctx));
if (*ctx != NULL) {
ret = task_ctx_init(domain, *ctx, tgid, priv, release);
if (ret != 0) {
vfree(*ctx);
*ctx = NULL;
}
}
return ret;
}
static inline void task_ctx_destroy(struct task_ctx *ctx)
{
u32 bkt = task_get_bkt_from_tgid(ctx->tgid);
atomic_dec(&ctx->domain->task_num);
write_lock_bh(&ctx->domain->ctx_lock[bkt]);
task_ctx_del(ctx);
write_unlock_bh(&ctx->domain->ctx_lock[bkt]);
task_ctx_put(ctx);
}
static inline struct task_ctx *task_ctx_get_first(struct hlist_head *head, rwlock_t *lock)
{
struct task_ctx *ctx = NULL;
read_lock_bh(lock);
ctx = hlist_entry_safe(head->first, struct task_ctx, link);
if (ctx != NULL) {
if (kref_safe_get_unless_zero(&ctx->ref) == 0) {
ctx = NULL;
}
}
read_unlock_bh(lock);
return ctx;
}
static inline struct task_ctx *task_ctx_get_next(struct task_ctx *ctx_from, rwlock_t *lock)
{
struct task_ctx *ctx = NULL;
read_lock_bh(lock);
ctx = hlist_entry_safe(ctx_from->link.next, struct task_ctx, link);
if (ctx != NULL) {
if (kref_safe_get_unless_zero(&ctx->ref) == 0) {
ctx = NULL;
}
}
read_unlock_bh(lock);
return ctx;
}
static inline void _task_ctx_for_each_safe(struct hlist_head *head, rwlock_t *lock,
void *priv, void (*func)(struct task_ctx *ctx, void *priv))
{
struct task_ctx *ctx = NULL;
ctx = task_ctx_get_first(head, lock);
while (ctx != NULL) {
struct task_ctx *next_ctx = task_ctx_get_next(ctx, lock);
func(ctx, priv);
task_ctx_put(ctx);
ctx = next_ctx;
}
}
static inline void task_ctx_for_each_safe(struct task_ctx_domain *domain,
void *priv, void (*func)(struct task_ctx *ctx, void *priv))
{
u32 bkt;
for (bkt = 0; bkt < TASK_CTX_HASH_MAX; bkt++) {
_task_ctx_for_each_safe(&domain->ctx_htable[bkt], &domain->ctx_lock[bkt], priv, func);
}
}
static inline void task_ctx_item_show(struct task_ctx *ctx, void *priv)
{
struct seq_file *seq = (struct seq_file *)priv;
seq_printf(seq, " name: %s tgid %d ref %d\n", ctx->name, ctx->tgid, kref_safe_read(&ctx->ref));
}
static inline void task_ctx_domain_show(struct task_ctx_domain *domain, struct seq_file *seq)
{
seq_printf(seq, "domain: %s task_num %d task_max_num %u\n",
domain->name, atomic_read(&domain->task_num), domain->task_max_num);
task_ctx_for_each_safe(domain, seq, task_ctx_item_show);
seq_printf(seq, "\n");
}
static inline int _task_ctx_lock_call_func(struct task_ctx_domain *domain, int tgid,
struct task_start_time *time, int (*func)(struct task_ctx *ctx, void *para), void *para)
{
int ret = -ESRCH;
struct task_ctx *ctx = _task_ctx_get(domain, tgid, time);
if (ctx != NULL) {
mutex_lock(&ctx->mutex);
ret = func(ctx, para);
mutex_unlock(&ctx->mutex);
task_ctx_put(ctx);
}
return ret;
}
static inline int task_ctx_lock_call_func(struct task_ctx_domain *domain, int tgid,
int (*func)(struct task_ctx *ctx, void *para), void *para)
{
return _task_ctx_lock_call_func(domain, tgid, NULL, func, para);
}
static inline int task_ctx_time_check_lock_call_func(struct task_ctx_domain *domain, int tgid,
struct task_start_time *time, int (*func)(struct task_ctx *ctx, void *para), void *para)
{
return _task_ctx_lock_call_func(domain, tgid, time, func, para);
}
static inline int _task_ctx_lock_call_func_non_block(struct task_ctx_domain *domain, int tgid,
struct task_start_time *time, int (*func)(struct task_ctx *ctx, void *para), void *para)
{
int ret = -ESRCH;
struct task_ctx *ctx = _task_ctx_get(domain, tgid, time);
if (ctx != NULL) {
spin_lock_bh(&ctx->lock);
ret = func(ctx, para);
spin_unlock_bh(&ctx->lock);
task_ctx_put(ctx);
}
return ret;
}
static inline int task_ctx_lock_call_func_non_block(struct task_ctx_domain *domain, int tgid,
int (*func)(struct task_ctx *ctx, void *para), void *para)
{
return _task_ctx_lock_call_func_non_block(domain, tgid, NULL, func, para);
}
static inline int task_ctx_time_check_lock_call_func_non_block(struct task_ctx_domain *domain, int tgid,
struct task_start_time *time, int (*func)(struct task_ctx *ctx, void *para), void *para)
{
return _task_ctx_lock_call_func_non_block(domain, tgid, time, func, para);
}
static inline bool task_is_exit(int tgid, struct task_start_time *start_time)
{
bool is_exit = true;
struct task_struct *task = task_get_by_tgid(tgid);
if (task != NULL) {
if ((TASK_START_TIME_EQUAL(task->group_leader->start_time, start_time->time)) && !(task->flags & PF_EXITING)) {
is_exit = false;
}
put_task_struct(task);
}
return is_exit;
}
static inline bool dbl_task_is_dead(int tgid, struct task_start_time *start_time)
{
bool is_dead = true;
struct task_struct *task = task_get_by_tgid(tgid);
if (task != NULL) {
is_dead = !(TASK_START_TIME_EQUAL(task->group_leader->start_time, start_time->time));
put_task_struct(task);
}
return is_dead;
}
static inline int task_get_tgid_by_vpid(int vpid, int *tgid)
{
struct task_struct *task = NULL;
rcu_read_lock();
task = get_pid_task(find_vpid(vpid), PIDTYPE_PID);
rcu_read_unlock();
if (task == NULL) {
return -ESRCH;
}
*tgid = (int)task_tgid_nr(task);
put_task_struct(task);
return 0;
}
static inline int task_get_parent_tgid_by_vpid(int vpid, int *tgid)
{
struct task_struct *task = NULL;
rcu_read_lock();
task = get_pid_task(find_vpid(vpid), PIDTYPE_PID);
rcu_read_unlock();
if (task == NULL) {
return -ESRCH;
}
*tgid = (int)task->real_parent->tgid;
put_task_struct(task);
return 0;
}
static inline bool task_is_current_child_or_current(int vpid)
{
int parent_tgid;
if ((int)task_tgid_vnr(current) == vpid) {
return true;
}
if (task_get_parent_tgid_by_vpid(vpid, &parent_tgid) != 0) {
return false;
}
return (task_tgid_nr(current) == parent_tgid);
}
#endif