2023-10-09 11:03:41 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: maxkb
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: document_serializers.py
|
|
|
|
|
|
@date:2023/9/22 13:43
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
2024-01-03 03:51:48 +00:00
|
|
|
|
import logging
|
2023-10-24 12:24:32 +00:00
|
|
|
|
import os
|
2024-01-03 03:51:48 +00:00
|
|
|
|
import traceback
|
2023-10-09 11:03:41 +00:00
|
|
|
|
import uuid
|
|
|
|
|
|
from functools import reduce
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from typing import List, Dict
|
2023-10-09 11:03:41 +00:00
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from django.db import transaction
|
2023-10-09 11:03:41 +00:00
|
|
|
|
from django.db.models import QuerySet
|
|
|
|
|
|
from drf_yasg import openapi
|
|
|
|
|
|
from rest_framework import serializers
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from common.db.search import native_search, native_page_search
|
2024-01-03 07:40:37 +00:00
|
|
|
|
from common.event.common import work_thread_pool
|
2024-01-17 08:08:51 +00:00
|
|
|
|
from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs
|
2023-10-09 11:03:41 +00:00
|
|
|
|
from common.exception.app_exception import AppApiException
|
|
|
|
|
|
from common.mixins.api_mixin import ApiMixin
|
2023-12-21 08:55:11 +00:00
|
|
|
|
from common.util.common import post
|
2024-03-04 02:12:18 +00:00
|
|
|
|
from common.util.field_message import ErrMessage
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from common.util.file_util import get_file_content
|
2024-01-03 03:51:48 +00:00
|
|
|
|
from common.util.fork import Fork
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from common.util.split_model import SplitModel, get_split_model
|
2024-03-11 09:28:05 +00:00
|
|
|
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping
|
2024-01-19 08:47:18 +00:00
|
|
|
|
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
|
|
|
|
|
from smartdoc.conf import PROJECT_DIR
|
2023-10-09 11:03:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
2024-01-19 08:47:18 +00:00
|
|
|
|
class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
    """Validate the payload used to edit an existing document.

    Every field is optional. When ``meta`` is supplied it is additionally
    validated by the meta serializer registered for the document's ``type``
    (see ``get_meta_valid_map``).
    """
    meta = serializers.DictField(required=False)
    name = serializers.CharField(required=False, max_length=128, min_length=1,
                                 error_messages=ErrMessage.char(
                                     "文档名称"))
    # BooleanField: use the boolean error-message set, not the char one.
    is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(
        "文档是否可用"))

    @staticmethod
    def get_meta_valid_map():
        """Return the mapping from document ``Type`` to the serializer class validating its ``meta``."""
        dataset_meta_valid_map = {
            Type.base: MetaSerializer.BaseMeta,
            Type.web: MetaSerializer.WebMeta
        }
        return dataset_meta_valid_map

    def is_valid(self, *, document: Document = None):
        """Validate the fields, then validate ``meta`` against ``document``'s type.

        :param document: the document being edited; its ``type`` selects the meta validator.
        :raises ValidationError: when a field or the meta payload is invalid.
        :raises AppApiException: when ``meta`` is given but no validator exists for the
            document (``document`` missing or its type not in ``get_meta_valid_map``).
        """
        super().is_valid(raise_exception=True)
        meta = self.data.get('meta')
        if meta is not None:
            valid_class = None if document is None else self.get_meta_valid_map().get(document.type)
            if valid_class is None:
                # Previously this fell through to calling None(...) -> TypeError.
                raise AppApiException(500, "不支持校验该文档类型的元数据")
            valid_class(data=meta).is_valid(raise_exception=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-01-17 08:08:51 +00:00
|
|
|
|
class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
    """Validate the request body for importing web documents from a list of URLs.

    ``selector`` is an optional CSS-selector string (may be null or blank).
    """
    # ListField: use the list error-message set so list-specific errors read correctly.
    source_url_list = serializers.ListField(required=True,
                                            child=serializers.CharField(required=True, error_messages=ErrMessage.char(
                                                "文档地址")),
                                            error_messages=ErrMessage.list(
                                                "文档地址列表"))
    selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                     error_messages=ErrMessage.char(
                                         "选择器"))

    @staticmethod
    def get_request_body_api():
        """Describe the web-import request body for the OpenAPI (drf_yasg) schema."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['source_url_list'],
            properties={
                # Titles previously said "段落列表"/"文档名称" (copy-paste from other schemas).
                'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档地址列表",
                                                  description="文档地址列表",
                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
                'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
            }
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
    """Validate a single document payload: a name plus an optional list of paragraphs."""
    name = serializers.CharField(required=True,
                                 min_length=1,
                                 max_length=128,
                                 error_messages=ErrMessage.char("文档名称"))

    paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)

    @staticmethod
    def get_request_body_api():
        """Describe the document payload for the OpenAPI (drf_yasg) schema."""
        name_schema = openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称")
        paragraphs_schema = openapi.Schema(type=openapi.TYPE_ARRAY,
                                           title="段落列表",
                                           description="段落列表",
                                           items=ParagraphSerializers.Create.get_request_body_api())
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=['name', 'paragraphs'],
                              properties={'name': name_schema, 'paragraphs': paragraphs_schema})
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
|
|
|
|
|
class DocumentSerializers(ApiMixin, serializers.Serializer):
    """Namespace grouping the document serializers: query, sync, single-document
    operations, creation, file splitting, split patterns and batch operations."""

    @staticmethod
    def _get_list_document_sql():
        """Read the SELECT template (list_document.sql) used to render document rows."""
        return get_file_content(
            os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql'))

    class Query(ApiMixin, serializers.Serializer):
        """List or page the documents of one dataset, optionally filtered by a name substring."""
        # dataset id
        dataset_id = serializers.UUIDField(required=True,
                                           error_messages=ErrMessage.uuid(
                                               "知识库id"))

        name = serializers.CharField(required=False, max_length=128,
                                     min_length=1,
                                     error_messages=ErrMessage.char(
                                         "文档名称"))

        def get_query_set(self):
            """Build the Document queryset for ``dataset_id`` (and ``name`` substring if given), newest first."""
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
            name = self.data.get('name')
            if name is not None:
                query_set = query_set.filter(**{'name__contains': name})
            return query_set.order_by('-create_time')

        def list(self, with_valid=False):
            """Return all matching document rows rendered through list_document.sql."""
            if with_valid:
                self.is_valid(raise_exception=True)
            return native_search(self.get_query_set(),
                                 select_string=DocumentSerializers._get_list_document_sql())

        def page(self, current_page, page_size):
            """Return one page of matching document rows rendered through list_document.sql."""
            return native_page_search(current_page, page_size, self.get_query_set(),
                                      select_string=DocumentSerializers._get_list_document_sql())

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='name',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='文档名称')]

        @staticmethod
        def get_response_body_api():
            return openapi.Schema(type=openapi.TYPE_ARRAY,
                                  title="文档列表", description="文档列表",
                                  items=DocumentSerializers.Operate.get_response_body_api())

    class Sync(ApiMixin, serializers.Serializer):
        """Re-crawl a web-type document and rebuild its paragraphs and problem mappings."""
        document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(
            "文档id"))

        def is_valid(self, *, raise_exception=False):
            """Validate ``document_id`` and require an existing document of type web.

            Always raises on failure, regardless of ``raise_exception``.
            """
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            first = QuerySet(Document).filter(id=document_id).first()
            if first is None:
                raise AppApiException(500, "文档id不存在")
            if first.type != Type.web:
                raise AppApiException(500, "只有web站点类型才支持同步")

        @staticmethod
        def _get_selector_list(meta):
            """Split the space-separated ``selector`` stored in ``meta`` into a list.

            Empty items are dropped so a blank selector (allowed on import) yields ``[]``
            instead of ``['']``.
            """
            selector = meta.get('selector')
            if selector is None:
                return []
            return [s for s in selector.split(" ") if s]

        def sync(self, with_valid=True, with_embedding=True):
            """Re-fetch the document's ``source_url`` and replace its paragraphs.

            On HTTP 200 the old paragraphs/mappings are deleted and the new ones inserted
            in one transaction; the embedding signal is sent after commit when
            ``with_embedding`` is true. Any failure is logged and marks the document as
            ``Status.error``. Always returns True.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            document = QuerySet(Document).filter(id=document_id).first()
            if document.type != Type.web:
                return True
            try:
                document.status = Status.embedding
                document.save()
                source_url = document.meta.get('source_url')
                selector_list = self._get_selector_list(document.meta)
                result = Fork(source_url, selector_list).fork()
                if result.status == 200:
                    # Atomic so a failure mid-way does not leave the document with its
                    # paragraphs deleted but nothing re-inserted.
                    with transaction.atomic():
                        # delete paragraphs
                        QuerySet(model=Paragraph).filter(document_id=document_id).delete()
                        # delete problem-paragraph mappings
                        QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
                        # delete vectors
                        ListenerManagement.delete_embedding_by_document_signal.send(document_id)
                        paragraphs = get_split_model('web.md').parse(result.content)
                        document.char_length = sum(len(p.get('content')) for p in paragraphs)
                        document.save()
                        document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document,
                                                                                                   paragraphs)
                        paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
                        problem_model_list = document_paragraph_model.get('problem_model_list')
                        problem_paragraph_mapping_list = document_paragraph_model.get(
                            'problem_paragraph_mapping_list')
                        if paragraph_model_list:
                            QuerySet(Paragraph).bulk_create(paragraph_model_list)
                        if problem_model_list:
                            QuerySet(Problem).bulk_create(problem_model_list)
                        if problem_paragraph_mapping_list:
                            QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list)
                    # Send after commit so an asynchronous handler sees the new paragraphs.
                    if with_embedding:
                        ListenerManagement.embedding_by_document_signal.send(document_id)
                else:
                    document.status = Status.error
                    document.save()
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
                document.status = Status.error
                document.save()
            return True

    class Operate(ApiMixin, serializers.Serializer):
        """Read, edit, refresh or delete a single document."""
        document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(
            "文档id"))

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='知识库id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id')
                    ]

        def is_valid(self, *, raise_exception=False):
            """Validate ``document_id`` and require the document to exist (always raises on failure)."""
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            if not QuerySet(Document).filter(id=document_id).exists():
                raise AppApiException(500, "文档id不存在")

        def one(self, with_valid=False):
            """Return the single document row rendered through list_document.sql."""
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = QuerySet(model=Document).filter(**{'id': self.data.get("document_id")})
            return native_search(query_set, select_string=DocumentSerializers._get_list_document_sql(),
                                 with_search_one=True)

        def edit(self, instance: Dict, with_valid=False):
            """Apply non-None ``name``/``is_active``/``meta`` from ``instance`` and return the updated row."""
            if with_valid:
                self.is_valid(raise_exception=True)
            _document = QuerySet(Document).get(id=self.data.get("document_id"))
            if with_valid:
                DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
            for update_key in ('name', 'is_active', 'meta'):
                value = instance.get(update_key)
                if value is not None:
                    setattr(_document, update_key, value)
            _document.save()
            return self.one()

        def refresh(self, with_valid=True):
            """Re-process the document: web documents are re-synced on the worker pool,
            others get the embedding signal sent directly."""
            if with_valid:
                self.is_valid(raise_exception=True)
            document_id = self.data.get("document_id")
            document = QuerySet(Document).filter(id=document_id).first()
            if document.type == Type.web:
                # asynchronous sync; the placeholder {} argument fills the lambda's unused parameter
                work_thread_pool.submit(lambda x: DocumentSerializers.Sync(data={'document_id': document_id}).sync(),
                                        {})
            else:
                ListenerManagement.embedding_by_document_signal.send(document_id)
            return True

        @transaction.atomic
        def delete(self):
            """Delete the document, its paragraphs, problem mappings and vectors."""
            document_id = self.data.get("document_id")
            QuerySet(model=Document).filter(id=document_id).delete()
            # delete paragraphs
            QuerySet(model=Paragraph).filter(document_id=document_id).delete()
            # delete problem-paragraph mappings
            QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
            # delete vectors
            ListenerManagement.delete_embedding_by_document_signal.send(document_id)
            return True

        @staticmethod
        def get_response_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                # The comma after 'is_active' was missing, which merged two keys into
                # 'is_activeupdate_time'.
                required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active',
                          'update_time', 'create_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                           description="名称", default="测试知识库"),
                    'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
                                                  description="字符数", default=10),
                    'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
                    'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="段落数量",
                                                      description="段落数量", default=1),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                    'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="文档元数据",
                                           description="文档元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
                }
            )

    class Create(ApiMixin, serializers.Serializer):
        """Create documents (with paragraphs/problems) inside a dataset, directly or from web URLs."""
        # This is a dataset id; the label previously said "文档id".
        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(
            "知识库id"))

        def is_valid(self, *, raise_exception=False):
            """Validate ``dataset_id`` and require the dataset to exist (always raises on failure)."""
            super().is_valid(raise_exception=True)
            if not QuerySet(DataSet).filter(id=self.data.get('dataset_id')).exists():
                raise AppApiException(10000, "知识库id不存在")
            return True

        @staticmethod
        def post_embedding(result, document_id):
            """Post hook for ``save``: send the embedding signal for the new document, pass ``result`` through."""
            ListenerManagement.embedding_by_document_signal.send(document_id)
            return result

        @post(post_function=post_embedding)
        @transaction.atomic
        def save(self, instance: Dict, with_valid=False, **kwargs):
            """Insert one document with its paragraphs, problems and problem mappings.

            Returns ``(document_row, document_id)``; this matches ``post_embedding(result,
            document_id)`` — presumably ``post`` unpacks it into that hook.
            """
            if with_valid:
                DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
            document_model = document_paragraph_model.get('document')
            paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
            problem_model_list = document_paragraph_model.get('problem_model_list')
            problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')

            # insert document
            document_model.save()
            # bulk insert paragraphs
            if paragraph_model_list:
                QuerySet(Paragraph).bulk_create(paragraph_model_list)
            # bulk insert problems
            if problem_model_list:
                QuerySet(Problem).bulk_create(problem_model_list)
            # bulk insert problem-paragraph mappings
            if problem_paragraph_mapping_list:
                QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list)
            document_id = str(document_model.id)
            return DocumentSerializers.Operate(
                data={'dataset_id': dataset_id, 'document_id': document_id}).one(
                with_valid=True), document_id

        @staticmethod
        def get_sync_handler(dataset_id):
            """Build the per-URL callback used by the web-sync signal for ``dataset_id``.

            On HTTP 200 the page is split and saved as a web document; otherwise an empty
            document with ``Status.error`` is recorded in the same dataset.
            """

            def handler(source_url: str, selector, response: Fork.Response):
                if response.status == 200:
                    try:
                        paragraphs = get_split_model('web.md').parse(response.content)
                        # insert
                        DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                            {'name': source_url, 'paragraphs': paragraphs,
                             'meta': {'source_url': source_url, 'selector': selector},
                             'type': Type.web}, with_valid=True)
                    except Exception as e:
                        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
                else:
                    # dataset_id was previously omitted, so the error document was not
                    # attached to any dataset.
                    Document(name=source_url,
                             dataset_id=dataset_id,
                             meta={'source_url': source_url, 'selector': selector},
                             type=Type.web,
                             char_length=0,
                             status=Status.error).save()

            return handler

        def save_web(self, instance: Dict, with_valid=True):
            """Send the web-sync signal for ``instance['source_url_list']`` into this dataset."""
            if with_valid:
                DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            source_url_list = instance.get('source_url_list')
            selector = instance.get('selector')
            args = SyncWebDocumentArgs(source_url_list, selector, self.get_sync_handler(dataset_id))
            ListenerManagement.sync_web_document_signal.send(args)

        @staticmethod
        def get_paragraph_model(document_model, paragraph_list: List):
            """Build unsaved paragraph/problem/mapping models for ``paragraph_list`` under ``document_model``.

            :return: dict with keys ``document``, ``paragraph_model_list``,
                ``problem_model_list`` and ``problem_paragraph_mapping_list``.
            """
            dataset_id = document_model.dataset_id
            # One serializer serves every paragraph; its data does not vary per paragraph.
            paragraph_serializer = ParagraphSerializers.Create(
                data={'dataset_id': dataset_id, 'document_id': str(document_model.id)})
            paragraph_model_list = []
            problem_model_list = []
            problem_paragraph_mapping_list = []
            for paragraph in paragraph_list:
                paragraphs = paragraph_serializer.get_paragraph_problem_model(dataset_id, document_model.id,
                                                                              paragraph)
                paragraph_model_list.append(paragraphs.get('paragraph'))
                problem_model_list.extend(paragraphs.get('problem_model_list'))
                problem_paragraph_mapping_list.extend(paragraphs.get('problem_paragraph_mapping_list'))
            return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
                    'problem_model_list': problem_model_list,
                    'problem_paragraph_mapping_list': problem_paragraph_mapping_list}

        @staticmethod
        def get_document_paragraph_model(dataset_id, instance: Dict):
            """Build an unsaved Document from ``instance`` plus its paragraph/problem models.

            ``meta`` defaults to ``{}`` and ``type`` to ``Type.base``; ``char_length`` is the
            total length of the paragraphs' ``content``.
            """
            # 'paragraphs' may be absent or explicitly None (the serializer allows null).
            paragraphs = instance.get('paragraphs') or []
            meta = instance.get('meta')
            document_type = instance.get('type')
            document_model = Document(
                **{'dataset_id': dataset_id,
                   'id': uuid.uuid1(),
                   'name': instance.get('name'),
                   'char_length': sum(len(p.get('content')) for p in paragraphs),
                   'meta': meta if meta is not None else {},
                   'type': document_type if document_type is not None else Type.base})
            return DocumentSerializers.Create.get_paragraph_model(document_model, paragraphs)

        @staticmethod
        def get_request_body_api():
            return DocumentInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='知识库id')
                    ]

    class Split(ApiMixin, serializers.Serializer):
        """Validate uploaded files and split them into paragraphs."""
        file = serializers.ListField(required=True, error_messages=ErrMessage.list(
            "文件列表"))

        limit = serializers.IntegerField(required=False, error_messages=ErrMessage.integer(
            "分段长度"))

        # ListField: previously used the uuid error-message set.
        patterns = serializers.ListField(required=False,
                                         child=serializers.CharField(required=True, error_messages=ErrMessage.char(
                                             "分段标识")),
                                         error_messages=ErrMessage.list(
                                             "分段标识列表"))

        with_filter = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(
            "自动清洗"))

        def is_valid(self, *, raise_exception=True):
            """Validate fields and reject any file larger than 10 MiB."""
            super().is_valid(raise_exception=True)
            files = self.data.get('file')
            for f in files:
                if f.size > 1024 * 1024 * 10:
                    raise AppApiException(500, "上传文件最大不能超过10m")

        @staticmethod
        def get_request_params_api():
            return [
                openapi.Parameter(name='file',
                                  in_=openapi.IN_FORM,
                                  type=openapi.TYPE_ARRAY,
                                  items=openapi.Items(type=openapi.TYPE_FILE),
                                  required=True,
                                  description='上传文件'),
                openapi.Parameter(name='limit',
                                  in_=openapi.IN_FORM,
                                  required=False,
                                  type=openapi.TYPE_INTEGER, title="分段长度", description="分段长度"),
                openapi.Parameter(name='patterns',
                                  in_=openapi.IN_FORM,
                                  required=False,
                                  type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_STRING),
                                  title="分段正则列表", description="分段正则列表"),
                openapi.Parameter(name='with_filter',
                                  in_=openapi.IN_FORM,
                                  required=False,
                                  type=openapi.TYPE_BOOLEAN, title="是否清除特殊字符", description="是否清除特殊字符"),
            ]

        def parse(self):
            """Split every uploaded file; returns one ``{'name', 'content'}`` dict per file."""
            file_list = self.data.get("file")
            return list(
                map(lambda f: file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
                                                self.data.get("limit", None)), file_list))

    class SplitPattern(ApiMixin, serializers.Serializer):
        """Expose the built-in split patterns (display key -> regex)."""

        @staticmethod
        def list():
            return [{'key': "#", 'value': '(?<=^)# .*|(?<=\\n)# .*'}, {'key': '##', 'value': '(?<!#)## (?!#).*'},
                    {'key': '###', 'value': "(?<!#)### (?!#).*"}, {'key': '####', 'value': "(?<!#)#### (?!#).*"},
                    {'key': '#####', 'value': "(?<!#)##### (?!#).*"},
                    {'key': '######', 'value': "(?<!#)###### (?!#).*"},
                    {'key': '-', 'value': '(?<! )- .*'},
                    {'key': '空格', 'value': '(?<!\\s)\\s(?!\\s)'},
                    {'key': '分号', 'value': '(?<!;);(?!;)'}, {'key': '逗号', 'value': '(?<!,),(?!,)'},
                    {'key': '句号', 'value': '(?<!。)。(?!。)'}, {'key': '回车', 'value': '(?<!\\n)\\n(?!\\n)'},
                    {'key': '空行', 'value': '(?<!\\n)\\n\\n(?!\\n)'}]

    class Batch(ApiMixin, serializers.Serializer):
        """Batch save, sync and delete documents of one dataset."""
        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())

        @staticmethod
        def post_embedding(document_list):
            """Post hook for ``batch_save``: send the embedding signal for each saved document."""
            for document_dict in document_list:
                ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
            return document_list

        @post(post_function=post_embedding)
        @transaction.atomic
        def batch_save(self, instance_list: List[Dict], with_valid=True):
            """Insert several documents with their paragraphs, problems and mappings.

            Returns a 1-tuple ``(rows,)``; this matches ``post_embedding(document_list)`` —
            presumably ``post`` unpacks it into that hook, so keep the tuple.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
            dataset_id = self.data.get("dataset_id")
            document_model_list = []
            paragraph_model_list = []
            problem_model_list = []
            problem_paragraph_mapping_list = []
            # build the document models
            for document in instance_list:
                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
                                                                                                        document)
                document_model_list.append(document_paragraph_dict_model.get('document'))
                paragraph_model_list.extend(document_paragraph_dict_model.get('paragraph_model_list'))
                problem_model_list.extend(document_paragraph_dict_model.get('problem_model_list'))
                problem_paragraph_mapping_list.extend(
                    document_paragraph_dict_model.get('problem_paragraph_mapping_list'))

            # insert documents
            if document_model_list:
                QuerySet(Document).bulk_create(document_model_list)
            # bulk insert paragraphs
            if paragraph_model_list:
                QuerySet(Paragraph).bulk_create(paragraph_model_list)
            # bulk insert problems
            if problem_model_list:
                QuerySet(Problem).bulk_create(problem_model_list)
            # bulk insert problem-paragraph mappings
            if problem_paragraph_mapping_list:
                QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list)
            # query the inserted documents
            query_set = QuerySet(model=Document).filter(**{'id__in': [d.id for d in document_model_list]})
            return (native_search(query_set, select_string=DocumentSerializers._get_list_document_sql(),
                                  with_search_one=False),)

        @staticmethod
        def _batch_sync(document_id_list: List[str]):
            """Sync each document in turn (runs on the worker pool)."""
            for document_id in document_id_list:
                DocumentSerializers.Sync(data={'document_id': document_id}).sync()

        def batch_sync(self, instance: Dict, with_valid=True):
            """Schedule ``instance['id_list']`` for asynchronous sync; returns True immediately."""
            if with_valid:
                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
                self.is_valid(raise_exception=True)
            # asynchronous sync
            work_thread_pool.submit(self._batch_sync,
                                    instance.get('id_list'))
            return True

        def batch_delete(self, instance: Dict, with_valid=True):
            """Delete the documents in ``instance['id_list']`` with their paragraphs,
            problem mappings and vectors."""
            if with_valid:
                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
                self.is_valid(raise_exception=True)
            document_id_list = instance.get("id_list")
            with transaction.atomic():
                QuerySet(Document).filter(id__in=document_id_list).delete()
                QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
                # Remove problem-paragraph mappings by document, as Operate.delete and
                # Sync.sync do, instead of filtering Problem by document_id.
                QuerySet(ProblemParagraphMapping).filter(document_id__in=document_id_list).delete()
            # delete vectors
            ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list)
            return True
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
2023-11-20 10:53:18 +00:00
|
|
|
|
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
    """Decode an uploaded file as UTF-8 and split it into paragraphs.

    :param file: object exposing ``read()`` (returning bytes) and ``name``.
    :param pattern_list: custom split regexes; when None/empty the split model is
        chosen from ``file.name`` via ``get_split_model``.
    :param with_filter: forwarded to the split model.
    :param limit: forwarded to the split model.
    :return: ``{'name': file.name, 'content': [...]}``; ``content`` is ``[]`` when the
        file is not valid UTF-8.
    """
    data = file.read()
    try:
        content = data.decode('utf-8')
    except UnicodeDecodeError:
        # Not UTF-8 text: report the file with no paragraphs rather than failing the
        # whole request. (Previously BaseException, which also swallowed
        # KeyboardInterrupt/SystemExit.)
        return {'name': file.name, 'content': []}
    if pattern_list:
        split_model = SplitModel(pattern_list, with_filter, limit)
    else:
        split_model = get_split_model(file.name, with_filter=with_filter, limit=limit)
    return {'name': file.name,
            'content': split_model.parse(content)
            }
|