2023-11-20 10:53:18 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: maxkb
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: common_serializers.py
|
|
|
|
|
|
@date:2023/11/17 11:00
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
|
|
|
|
|
import os
|
2024-07-01 11:39:07 +00:00
|
|
|
|
import uuid
|
2023-12-25 09:10:59 +00:00
|
|
|
|
from typing import List
|
2023-11-20 10:53:18 +00:00
|
|
|
|
|
2023-12-25 09:10:59 +00:00
|
|
|
|
from django.db.models import QuerySet
|
2024-01-17 08:08:51 +00:00
|
|
|
|
from drf_yasg import openapi
|
|
|
|
|
|
from rest_framework import serializers
|
2023-12-25 09:10:59 +00:00
|
|
|
|
|
2024-07-19 02:34:47 +00:00
|
|
|
|
from common.config.embedding_config import ModelManage
|
2023-12-25 09:10:59 +00:00
|
|
|
|
from common.db.search import native_search
|
2023-11-20 10:53:18 +00:00
|
|
|
|
from common.db.sql_execute import update_execute
|
2024-01-17 08:08:51 +00:00
|
|
|
|
from common.exception.app_exception import AppApiException
|
|
|
|
|
|
from common.mixins.api_mixin import ApiMixin
|
2024-03-04 02:12:18 +00:00
|
|
|
|
from common.util.field_message import ErrMessage
|
2023-11-20 10:53:18 +00:00
|
|
|
|
from common.util.file_util import get_file_content
|
2024-01-19 08:47:18 +00:00
|
|
|
|
from common.util.fork import Fork
|
2024-07-17 09:01:57 +00:00
|
|
|
|
from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
|
|
|
|
|
|
from setting.models_provider import get_model
|
2023-11-20 10:53:18 +00:00
|
|
|
|
from smartdoc.conf import PROJECT_DIR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_document_char_length(document_id: str):
    """Run the ``update_document_char_length.sql`` statement for one document.

    The SQL text is read from ``apps/dataset/sql`` under ``PROJECT_DIR`` and
    executed through ``update_execute``.

    :param document_id: id of the document; it is bound twice because the
        statement is executed with the parameter tuple ``(document_id, document_id)``.
    """
    sql_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')
    sql_text = get_file_content(sql_path)
    sql_params = (document_id, document_id)
    update_execute(sql_text, sql_params)
|
2023-12-25 09:10:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def list_paragraph(paragraph_list: List[str]):
    """Look up paragraphs by id using the ``list_paragraph.sql`` template.

    :param paragraph_list: paragraph ids to fetch; may be ``None`` or empty.
    :return: ``[]`` when no ids are given, otherwise whatever ``native_search``
        returns for the ``Paragraph`` queryset filtered by ``id__in``.
    """
    # Guard first: nothing to look up means no SQL file read and no query.
    if paragraph_list is None or not len(paragraph_list):
        return []

    sql_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')
    paragraph_query = QuerySet(Paragraph).filter(id__in=paragraph_list)
    return native_search(paragraph_query, get_file_content(sql_path))
|
2024-01-17 08:08:51 +00:00
|
|
|
|
|
|
|
|
|
|
|
2024-01-19 08:47:18 +00:00
|
|
|
|
class MetaSerializer(serializers.Serializer):
    """Container for the serializers that validate ``meta`` payloads."""

    class WebMeta(serializers.Serializer):
        """Meta carrying a required ``source_url`` and an optional ``selector``."""

        # Error-message label "文档地址" means "document address".
        source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("文档地址"))
        # Error-message label "选择器" means "selector"; may be omitted, null or blank.
        selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                         error_messages=ErrMessage.char("选择器"))

        def is_valid(self, *, raise_exception=False):
            """Validate the fields, then probe ``source_url`` with ``Fork``.

            Field validation always raises on failure; the ``raise_exception``
            argument is accepted for signature compatibility but not consulted.

            :return: ``True`` when validation succeeds, matching the boolean
                contract of the base ``is_valid`` (the override previously
                returned ``None``, which is falsy even on success).
            :raises AppApiException: code 500 when the fork response reports
                status 500 for ``source_url``.
            """
            super().is_valid(raise_exception=True)
            source_url = self.data.get('source_url')
            response = Fork(source_url, []).fork()
            if response.status == 500:
                raise AppApiException(500, f"url错误,无法解析【{source_url}】")
            return True

    class BaseMeta(serializers.Serializer):
        """Meta without any declared fields."""

        def is_valid(self, *, raise_exception=False):
            """Run base validation, always raising on failure.

            :return: ``True`` when validation succeeds (previously ``None``).
            """
            super().is_valid(raise_exception=True)
            return True
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-01-17 08:08:51 +00:00
|
|
|
|
class BatchSerializer(ApiMixin, serializers.Serializer):
    """Validate a request body that carries a list of UUID primary keys."""

    # Error-message label "id列表" means "id list".
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                    error_messages=ErrMessage.char("id列表"))

    def is_valid(self, *, model=None, raise_exception=False):
        """Validate the payload and optionally check the ids exist for ``model``.

        Field validation always raises on failure; ``raise_exception`` is
        accepted for signature compatibility but not consulted.

        :param model: optional model class; when given, every id in
            ``id_list`` must match a row of that model.
        :return: ``True`` when validation succeeds (previously ``None``).
        :raises AppApiException: code 500 listing the ids that have no
            matching row, in the order they were submitted.
        """
        super().is_valid(raise_exception=True)
        if model is not None:
            id_list = self.data.get('id_list')
            # A set gives O(1) membership tests instead of scanning a list per id.
            existing_ids = {str(row.id) for row in QuerySet(model).filter(id__in=id_list)}
            error_id_list = [row_id for row_id in id_list if row_id not in existing_ids]
            # Decide on the ids that are really missing rather than on a length
            # comparison: a duplicated id made the lengths differ and raised an
            # error with an empty id list even though every id existed.
            if error_id_list:
                raise AppApiException(500, f"id不正确:{error_id_list}")
        return True

    @staticmethod
    def get_request_body_api():
        """Return the OpenAPI schema of the body: an object with a string array ``id_list``."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            properties={
                # Title/description "主键id列表" means "primary key id list".
                'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
                                          title="主键id列表",
                                          description="主键id列表")
            }
        )
|
2024-07-01 11:39:07 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProblemParagraphObject:
    """Plain holder pairing one problem text with the paragraph it refers to."""

    def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str):
        """Store the four values unchanged as same-named attributes."""
        # Identifiers locating the paragraph.
        self.dataset_id, self.document_id, self.paragraph_id = dataset_id, document_id, paragraph_id
        # Text of the problem attached to that paragraph.
        self.problem_content = problem_content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict):
    """Resolve the problem object for ``content``, creating one if needed.

    Lookup order: the ``problem_content_dict`` cache, then the first row of
    ``exists_problem_list`` whose ``content`` matches, then a newly built
    ``Problem``. Cache entries are ``(problem, is_create)`` tuples where
    ``is_create`` is ``True`` only for objects built here.

    :param exists_problem_list: iterable of rows exposing a ``content`` attribute.
    :param content: problem text to resolve.
    :param dataset_id: dataset id assigned to a newly built ``Problem``.
    :param document_id: passed through unchanged in the result.
    :param paragraph_id: passed through unchanged in the result.
    :param problem_content_dict: cache keyed by content; mutated in place.
    :return: tuple ``(problem, document_id, paragraph_id)``.
    """
    if content in problem_content_dict:
        cached_problem = problem_content_dict.get(content)[0]
        return cached_problem, document_id, paragraph_id

    matches = [candidate for candidate in exists_problem_list if candidate.content == content]
    if matches:
        problem, is_create = matches[0], False
    else:
        problem, is_create = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id), True

    # Remember the outcome so later calls with the same content reuse it.
    problem_content_dict[content] = (problem, is_create)
    return problem, document_id, paragraph_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProblemParagraphManage:
    """Build ``Problem`` and ``ProblemParagraphMapping`` objects for a batch of problems."""

    def __init__(self, problemParagraphObjectList: List["ProblemParagraphObject"], dataset_id):
        """
        :param problemParagraphObjectList: entries pairing a problem text with
            its dataset, document and paragraph ids.
        :param dataset_id: dataset used for the existing-problem lookup and
            stamped on every mapping that is built.
        """
        self.dataset_id = dataset_id
        self.problemParagraphObjectList = problemParagraphObjectList

    def to_problem_model_list(self):
        """Resolve every entry to a problem and build its paragraph mapping.

        Nothing is saved by this method; it only constructs objects.

        :return: tuple ``(new_problems, mappings)`` where ``new_problems``
            holds only the ``Problem`` objects created during this call (one
            per distinct content) and ``mappings`` holds one
            ``ProblemParagraphMapping`` per input entry.
        """
        problem_list = [item.problem_content for item in self.problemParagraphObjectList]
        exists_problem_list = []
        if len(self.problemParagraphObjectList) > 0:
            # Look up the problems that already exist in this dataset.
            exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id,
                                                           content__in=problem_list).all()
        # content -> (problem, is_create); filled in by or_get as entries are resolved.
        problem_content_dict = {}
        problem_model_list = [
            or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.dataset_id,
                   problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict)
            for problemParagraphObject in self.problemParagraphObjectList]

        problem_paragraph_mapping_list = [
            ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
                                    paragraph_id=paragraph_id,
                                    dataset_id=self.dataset_id)
            for problem_model, document_id, paragraph_id in problem_model_list]

        # Only problems flagged as created are returned as new.
        new_problem_list = [problem_model for problem_model, is_create in problem_content_dict.values()
                            if is_create]
        return new_problem_list, problem_paragraph_mapping_list
|
2024-07-17 09:01:57 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
    """Return the embedding model shared by all of the given datasets.

    :param dataset_id_list: ids of the datasets to inspect.
    :return: result of ``ModelManage.get_model`` for the first dataset's
        ``embedding_mode_id``.
    :raises Exception: when the datasets reference more than one embedding
        model id, or when no dataset matches the given ids.
    """
    datasets = list(QuerySet(DataSet).filter(id__in=dataset_id_list))
    distinct_model_ids = {item.embedding_mode_id for item in datasets}
    if len(distinct_model_ids) > 1:
        raise Exception("知识库未向量模型不一致")
    if not datasets:
        raise Exception("知识库设置错误,请重新设置知识库")

    first_dataset = datasets[0]
    # NOTE(review): the callable looks like a factory ModelManage calls on demand — confirm.
    return ModelManage.get_model(str(first_dataset.embedding_mode_id),
                                 lambda _id: get_model(first_dataset.embedding_mode))
|
2024-07-17 09:01:57 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding_model_by_dataset_id(dataset_id: str):
    """Return the embedding model configured for one dataset.

    :param dataset_id: id of the dataset to look up.
    :return: result of ``ModelManage.get_model`` for the dataset's
        ``embedding_mode_id``.
    :raises Exception: when no dataset has the given id.
    """
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
    # first() yields None for an unknown id; fail with an explicit message
    # instead of an AttributeError on None.
    if dataset is None:
        raise Exception("知识库设置错误,请重新设置知识库")
    # NOTE(review): the callable looks like a factory ModelManage calls on demand — confirm.
    return ModelManage.get_model(str(dataset.embedding_mode_id), lambda _id: get_model(dataset.embedding_mode))
|
2024-07-18 07:44:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding_model_by_dataset(dataset):
    """Return the embedding model for an already loaded dataset object.

    :param dataset: object exposing ``embedding_mode_id`` and ``embedding_mode``.
    :return: result of ``ModelManage.get_model`` keyed by the stringified
        ``embedding_mode_id``.
    """
    model_key = str(dataset.embedding_mode_id)

    def _build_model(_id):
        # The related embedding_mode is only read if this callable is invoked.
        return get_model(dataset.embedding_mode)

    return ModelManage.get_model(model_key, _build_model)
|
2024-08-21 06:46:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding_model_id_by_dataset_id(dataset_id):
    """Return the embedding model id of one dataset as a string.

    :param dataset_id: id of the dataset to look up.
    :return: ``str(dataset.embedding_mode_id)``.
    :raises Exception: when no dataset has the given id.
    """
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
    # first() yields None for an unknown id; fail with an explicit message
    # instead of an AttributeError on None.
    if dataset is None:
        raise Exception("知识库设置错误,请重新设置知识库")
    return str(dataset.embedding_mode_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List):
    """Return the embedding model id shared by all of the given datasets.

    :param dataset_id_list: ids of the datasets to inspect.
    :return: the first dataset's ``embedding_mode_id`` as a string.
    :raises Exception: when the datasets reference more than one embedding
        model id, or when no dataset matches the given ids.
    """
    datasets = list(QuerySet(DataSet).filter(id__in=dataset_id_list))
    distinct_model_ids = {item.embedding_mode_id for item in datasets}
    if len(distinct_model_ids) > 1:
        raise Exception("知识库未向量模型不一致")
    if not datasets:
        raise Exception("知识库设置错误,请重新设置知识库")
    return str(datasets[0].embedding_mode_id)
|