# e6f7b4c2 - created on Jan 21 (commit history)
'''
Data-management related tasks.
'''

import json

from celery_app import celery_app

from web_apps.datasource.db_models import DataSource

from web_apps.datamodel.db_models import DataModel

from web_apps import db, app

from utils.log_utils import get_task_logger

from utils.common_utils import gen_uuid, parse_json

from utils.etl_utils import get_reader_model

from web_apps.rag.services.rag_service import train_datamodel, train_document



@celery_app.task(bind=True)
def self_gen_datasource_model(self, datasource_id):
    '''
    Automatically create data models for a data source.

    Loads the ``DataSource`` row, builds a reader for it through
    ``get_reader_model``, connects, and asks the reader for the models it can
    generate. Each generated model is stored as a new ``DataModel`` row unless
    an active (``del_flag == 0``) model of the same type whose ``model_conf``
    name matches already exists for this data source.

    :param datasource_id: primary key of the ``DataSource`` to model.
    :return: None. Failures are logged and never propagated to the caller.
    '''
    with app.app_context():
        task_id = self.request.id or gen_uuid()
        worker = self.request.hostname or ''
        logger = get_task_logger(p_name='self_gen_datasource_model', task_log_keys={'task_uuid': task_id})
        logger.info(f'任务开始,任务id:{task_id}, 执行worker:{worker}')
        try:
            datasource_obj = db.session.query(DataSource).filter(DataSource.id == datasource_id).first()
            if datasource_obj is None:
                # Unknown/deleted data source: report it explicitly instead of
                # failing later with an AttributeError on ``None``.
                logger.info(f'数据源不存在,数据源id:{datasource_id}')
                return

            model_info = {
                'source': {
                    "name": "",
                    "type": datasource_obj.type,
                    "conn_conf": parse_json(datasource_obj.conn_conf),
                    "ext_params": {}
                },
                'model': {},
                'extract_info': {
                    'batch_size': 1,
                    'extract_rules': []
                }
            }

            # ``flag`` stays False if either building the reader or connecting
            # fails; the second element of ``connect()`` is not needed here.
            flag, reader = get_reader_model(model_info)
            if flag:
                flag, _ = reader.connect()
            if not flag:
                logger.info('数据源连接失败')
                return

            for model in reader.gen_models():
                model_name = model['model_conf']['name']
                # The query is repeated per model so rows committed earlier in
                # this loop are also seen as existing.
                candidates = db.session.query(DataModel).filter(
                    DataModel.datasource_id == datasource_id,
                    DataModel.del_flag == 0,
                    DataModel.type == model['type'],
                ).all()
                # The name is compared against the stored model_conf, not the
                # ``name`` column.
                if any(parse_json(obj.model_conf)['name'] == model_name for obj in candidates):
                    continue

                # Keep the generated id in a local so it can be logged without
                # reading an attribute of the instance after the commit.
                model_id = gen_uuid()
                model_obj = DataModel(
                    id=model_id,
                    name=model_name,
                    datasource_id=datasource_id,
                    type=model['type'],
                    status=1,
                    model_conf=json.dumps(model['model_conf'], ensure_ascii=False),
                    can_interface=1,
                    create_by='system',
                    description="数据源自动建模创建模型"
                )
                db.session.add(model_obj)
                # commit() already flushes pending changes; no extra flush needed.
                db.session.commit()
                logger.info(f"数据源自动建模创建模型{model}成功,模型id为{model_id}")
        except Exception as e:
            logger.exception(e)
            # Discard any half-applied work so the session is not left in a
            # failed transaction state.
            db.session.rollback()





@celery_app.task(bind=True)
def self_train_rag_data(self, _id, metadata=None, train_type='document'):
    '''
    Run RAG training for a document or a data model.

    :param _id: identifier handed through to the selected training routine.
    :param metadata: optional metadata forwarded to the trainer; an empty
        dict is used when it is omitted.
    :param train_type: ``'document'`` selects ``train_document``; any other
        value selects ``train_datamodel``.
    :return: None. Exceptions raised during training are logged, not raised.
    '''
    metadata = {} if metadata is None else metadata

    with app.app_context():
        task_id = self.request.id or gen_uuid()
        worker_host = self.request.hostname or ''
        logger = get_task_logger(
            p_name='self_train_rag_data',
            task_log_keys={'task_uuid': task_id},
        )
        logger.info(f'任务开始,任务id:{task_id}, 执行worker:{worker_host}')

        try:
            # Pick the training routine by type, then invoke it once.
            trainer = train_document if train_type == 'document' else train_datamodel
            trainer(_id, metadata=metadata)
        except Exception as exc:
            logger.exception(exc)





# Manual debugging entry point: when this module is executed as a script, the
# data-source auto-modelling task is called directly (not via
# .delay()/.apply_async()) with a hard-coded data source id.
if __name__ == '__main__':

    # The task handles its own logging and has no return statement with a
    # value, so ``a`` is expected to be None.
    a = self_gen_datasource_model('fdf0938c7d5a44eca94ba093cc8be6c8')