'''
rag模块数据模型
'''
import pickle
from web_apps import db
from models import BaseModel
class Dataset(BaseModel):
    """Dataset table (``rag_dataset``)."""

    __tablename__ = 'rag_dataset'

    # String primary key, up to 36 characters; defaults to an empty string.
    id = db.Column(
        db.String(36),
        primary_key=True,
        default='',
        nullable=False,
        comment='id',
    )
    # Dataset name (indexed).
    name = db.Column(
        db.String(200),
        index=True,
        default='',
        nullable=True,
        comment='名称',
    )
    # Built-in flag: 1 = built-in, 0 = not built-in.
    built_in = db.Column(
        db.SmallInteger,
        default=0,
        nullable=True,
        comment='是否内置 1是 0不是',
    )
    # Status: 1 = enabled, 0 = disabled (indexed).
    status = db.Column(
        db.SmallInteger,
        index=True,
        default=1,
        nullable=True,
        comment='状态( 1为启用 0禁用)',
    )
class Document(BaseModel):
    """Document table (``rag_document``)."""

    __tablename__ = 'rag_document'

    # String primary key, up to 36 characters; defaults to an empty string.
    id = db.Column(
        db.String(36),
        primary_key=True,
        default='',
        nullable=False,
        comment='id',
    )
    # Id of the dataset this document is associated with (indexed).
    dataset_id = db.Column(
        db.String(36),
        index=True,
        default='',
        nullable=True,
        comment='数据集id',
    )
    # Document type.
    document_type = db.Column(
        db.String(32),
        default='',
        nullable=True,
        comment='文档类型',
    )
    # Document name (indexed).
    name = db.Column(
        db.String(200),
        index=True,
        default='',
        nullable=True,
        comment='名称',
    )
    # Status: 1 = pending training, 2 = training, 3 = trained, 4 = training failed.
    status = db.Column(
        db.SmallInteger,
        index=True,
        default=1,
        nullable=True,
        comment='状态( 1待训练,2训练中,3训练成功,4训练失败)',
    )
    # Document metadata, stored as text; the '{}' default suggests
    # JSON-encoded content -- TODO confirm against callers.
    meta_data = db.Column(
        db.Text,
        default='{}',
        nullable=True,
        comment='文档元信息',
    )
    # Chunking strategy, stored as text; presumably JSON as well -- TODO confirm.
    chunk_strategy = db.Column(
        db.Text,
        default='{}',
        nullable=True,
        comment='分段策略',
    )
class Chunk(BaseModel):
    """Chunk (segment) table (``rag_chunk``)."""

    __tablename__ = 'rag_chunk'

    # String primary key, up to 36 characters; defaults to an empty string.
    id = db.Column(
        db.String(36),
        primary_key=True,
        default='',
        nullable=False,
        comment='主键',
    )

    # --- Indexed id columns -------------------------------------------------
    # Dataset id.
    dataset_id = db.Column(
        db.String(36),
        index=True,
        default='',
        nullable=True,
        comment='数据集id',
    )
    # Document id.
    document_id = db.Column(
        db.String(36),
        index=True,
        default='',
        nullable=True,
        comment='文档id',
    )
    # Data source id.
    datasource_id = db.Column(
        db.String(36),
        index=True,
        default='',
        nullable=True,
        comment='数据源id',
    )
    # Data model id.
    datamodel_id = db.Column(
        db.String(36),
        index=True,
        default='',
        nullable=True,
        comment='数据模型id',
    )

    # --- Payload --------------------------------------------------------------
    # Chunk type: 'chunk' = text segment, 'qa' = question/answer pair.
    chunk_type = db.Column(
        db.String(32),
        index=True,
        default='chunk',
        nullable=True,
        comment='类型(chunk:文本分段 qa:问答对)',
    )
    # Question text.
    question = db.Column(
        db.Text,
        default='',
        nullable=True,
        comment='问题',
    )
    # Hash of the question (indexed, up to 32 characters).
    question_hash = db.Column(
        db.String(32),
        index=True,
        default='',
        nullable=True,
        comment='问题hash',
    )
    # Answer to the question.
    answer = db.Column(
        db.Text,
        default='',
        nullable=True,
        comment='问题回答',
    )
    # Chunk content.
    content = db.Column(
        db.Text,
        default='',
        nullable=True,
        comment='内容',
    )
    # Hash of the content (indexed, up to 32 characters).
    hash = db.Column(
        db.String(32),
        index=True,
        default='',
        nullable=True,
        comment='内容hash',
    )

    # --- Bookkeeping ----------------------------------------------------------
    # Position of the segment; defaults to 1.
    position = db.Column(
        db.Integer,
        default=1,
        nullable=True,
        comment='分段位置',
    )
    # Sync status: 1 = synced, 0 = not synced (indexed).
    status = db.Column(
        db.SmallInteger,
        index=True,
        default=1,
        nullable=True,
        comment='状态( 1已同步 0未同步)',
    )
    # Star flag: 1 = starred, 0 = not starred (indexed).
    star_flag = db.Column(
        db.SmallInteger,
        index=True,
        default=0,
        comment='标星状态( 1为标星 0没有标星)',
    )
class Embedding(db.Model):
    """Embedding table (``rag_embedding``).

    The vector is stored in the ``embedding`` binary column as a pickled
    Python object; use :meth:`set_embedding` and :meth:`get_embedding` to
    write and read it.
    """

    __tablename__ = 'rag_embedding'

    # String primary key, up to 36 characters (no default is configured).
    id = db.Column(db.String(36), primary_key=True)
    # Indexed hash, up to 32 characters; required.
    hash = db.Column(db.String(32), index=True, nullable=False)
    # Pickled embedding bytes; required.
    embedding = db.Column(db.LargeBinary, nullable=False)
    # Creation time, filled in by the database server.
    created_at = db.Column(
        db.DateTime,
        server_default=db.text('CURRENT_TIMESTAMP(0)'),
        nullable=False,
    )

    def set_embedding(self, embedding_data: list[float]):
        """Pickle ``embedding_data`` and store the bytes on ``self.embedding``.

        Args:
            embedding_data: The embedding values to serialize.
        """
        serialized = pickle.dumps(
            embedding_data,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
        self.embedding = serialized

    def get_embedding(self) -> list[float]:
        """Unpickle and return the object stored in ``self.embedding``.

        Returns:
            Whatever object was pickled into the column.
        """
        # NOTE(review): pickle.loads can execute arbitrary code when fed
        # attacker-controlled bytes; this is only safe if the column is
        # written exclusively through set_embedding -- confirm.
        raw_bytes = self.embedding
        return pickle.loads(raw_bytes)
if __name__ == '__main__':
    # Script entry point: create the database tables inside the Flask
    # application context. The app is imported only when run as a script.
    from web_apps import app

    with app.app_context():
        db.create_all()
        session = db.session
        session.commit()
        # NOTE(review): flush() right after commit() looks redundant -- confirm.
        session.flush()