2023-10-24 12:24:32 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: maxkb
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: paragraph_serializers.py
|
|
|
|
|
|
@date:2023/10/16 15:51
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
|
|
|
|
|
from django.db import transaction
|
|
|
|
|
|
from django.db.models import QuerySet
|
|
|
|
|
|
from drf_yasg import openapi
|
|
|
|
|
|
from rest_framework import serializers
|
|
|
|
|
|
|
|
|
|
|
|
from common.db.search import page_search
|
|
|
|
|
|
from common.event.listener_manage import ListenerManagement
|
|
|
|
|
|
from common.exception.app_exception import AppApiException
|
|
|
|
|
|
from common.mixins.api_mixin import ApiMixin
|
2023-12-21 08:55:11 +00:00
|
|
|
|
from common.util.common import post
|
2024-03-04 02:12:18 +00:00
|
|
|
|
from common.util.field_message import ErrMessage
|
2024-03-11 09:28:05 +00:00
|
|
|
|
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
2023-11-17 09:43:35 +00:00
|
|
|
|
from dataset.serializers.common_serializers import update_document_char_length
|
2024-03-11 09:28:05 +00:00
|
|
|
|
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
|
|
|
|
|
from embedding.models import SourceType
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParagraphSerializer(serializers.ModelSerializer):
    """
    Model serializer that renders a ``Paragraph`` row as a plain dict.
    """

    class Meta:
        model = Paragraph
        # Columns exposed when a paragraph is serialized.
        fields = [
            'id', 'content', 'is_active', 'document_id', 'title',
            'create_time', 'update_time',
        ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Paragraph instance payload: content, optional title, optional problem list
    and optional enabled flag.
    """
    # Paragraph text, at most 4096 characters. allow_null/allow_blank are set, so
    # None and blank strings pass validation.
    content = serializers.CharField(required=True,
                                    error_messages=ErrMessage.char("段落内容"),
                                    max_length=4096,
                                    min_length=1,
                                    allow_null=True,
                                    allow_blank=True)

    # Optional title, at most 256 characters.
    title = serializers.CharField(required=False,
                                  max_length=256,
                                  error_messages=ErrMessage.char("段落标题"),
                                  allow_null=True,
                                  allow_blank=True)

    # Optional list of problems (questions) attached to the paragraph.
    problem_list = ProblemInstanceSerializer(required=False, many=True)

    # Optional flag toggling whether the paragraph is enabled.
    is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char("段落是否可用"))

    @staticmethod
    def get_request_body_api():
        """Return the drf_yasg schema describing a paragraph request body."""
        content_schema = openapi.Schema(type=openapi.TYPE_STRING, max_length=4096, title="分段内容",
                                        description="分段内容")
        title_schema = openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题",
                                      description="分段标题")
        is_active_schema = openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用")
        problem_list_schema = openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
                                             description="问题列表",
                                             items=ProblemInstanceSerializer.get_request_body_api())
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['content'],
            properties={
                'content': content_schema,
                'title': title_schema,
                'is_active': is_active_schema,
                'problem_list': problem_list_schema,
            }
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-03-06 10:24:58 +00:00
|
|
|
|
class EditParagraphSerializers(serializers.Serializer):
    """
    Validates the body of a paragraph edit request; every field is optional.
    """
    # New title (at most 256 characters); null and blank are accepted.
    title = serializers.CharField(required=False,
                                  max_length=256,
                                  error_messages=ErrMessage.char("分段标题"),
                                  allow_null=True,
                                  allow_blank=True)
    # New content (at most 4096 characters); null and blank are accepted.
    content = serializers.CharField(required=False,
                                    max_length=4096,
                                    allow_null=True,
                                    allow_blank=True,
                                    error_messages=ErrMessage.char("分段内容"))
    # Replacement problem list, validated item by item.
    problem_list = ProblemInstanceSerializer(required=False, many=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
class ParagraphSerializers(ApiMixin, serializers.Serializer):
    """
    Validates the title/content of a new paragraph; the nested serializer
    classes implement the paragraph operations.
    """
    # Optional title, at most 256 characters; null and blank are accepted.
    title = serializers.CharField(required=False,
                                  max_length=256,
                                  error_messages=ErrMessage.char("分段标题"),
                                  allow_null=True,
                                  allow_blank=True)
    # Required content, at most 4096 characters.
    content = serializers.CharField(required=True,
                                    max_length=4096,
                                    error_messages=ErrMessage.char("分段内容"))
|
|
|
|
|
|
|
2024-03-11 09:28:05 +00:00
|
|
|
|
class Problem(ApiMixin, serializers.Serializer):
|
|
|
|
|
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
|
|
|
|
|
|
|
|
|
|
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
|
|
|
|
|
|
|
|
|
|
|
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
|
|
|
|
|
|
|
|
|
|
|
def is_valid(self, *, raise_exception=False):
|
|
|
|
|
|
super().is_valid(raise_exception=True)
|
|
|
|
|
|
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
|
|
|
|
|
|
raise AppApiException(500, "段落id不存在")
|
|
|
|
|
|
|
|
|
|
|
|
def list(self, with_valid=False):
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取问题列表
|
|
|
|
|
|
:param with_valid: 是否校验
|
|
|
|
|
|
:return: 问题列表
|
|
|
|
|
|
"""
|
|
|
|
|
|
if with_valid:
|
|
|
|
|
|
self.is_valid(raise_exception=True)
|
|
|
|
|
|
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"),
|
|
|
|
|
|
paragraph_id=self.data.get(
|
|
|
|
|
|
'paragraph_id'))
|
|
|
|
|
|
return [ProblemSerializer(row).data for row in
|
|
|
|
|
|
QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])]
|
|
|
|
|
|
|
|
|
|
|
|
@transaction.atomic
|
|
|
|
|
|
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
|
|
|
|
|
if with_valid:
|
|
|
|
|
|
self.is_valid()
|
|
|
|
|
|
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
|
|
|
|
|
problem = QuerySet(Problem).filter(dataset_id=self.data.get('dataset_id'),
|
|
|
|
|
|
content=instance.get('content')).first()
|
|
|
|
|
|
if problem is None:
|
|
|
|
|
|
problem = Problem(id=uuid.uuid1(), dataset_id=self.data.get('dataset_id'),
|
|
|
|
|
|
content=instance.get('content'))
|
|
|
|
|
|
problem.save()
|
|
|
|
|
|
if QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get('dataset_id'), problem_id=problem.id,
|
|
|
|
|
|
paragraph_id=self.data.get('paragraph_id')).exists():
|
|
|
|
|
|
raise AppApiException(500, "已经关联,请勿重复关联")
|
|
|
|
|
|
problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
|
|
|
|
|
|
problem_id=problem.id,
|
|
|
|
|
|
document_id=self.data.get('document_id'),
|
|
|
|
|
|
paragraph_id=self.data.get('paragraph_id'),
|
|
|
|
|
|
dataset_id=self.data.get('dataset_id'))
|
|
|
|
|
|
problem_paragraph_mapping.save()
|
|
|
|
|
|
if with_embedding:
|
|
|
|
|
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
|
|
|
|
|
'is_active': True,
|
|
|
|
|
|
'source_type': SourceType.PROBLEM,
|
|
|
|
|
|
'source_id': problem_paragraph_mapping.id,
|
|
|
|
|
|
'document_id': self.data.get('document_id'),
|
|
|
|
|
|
'paragraph_id': self.data.get('paragraph_id'),
|
|
|
|
|
|
'dataset_id': self.data.get('dataset_id'),
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return ProblemSerializers.Operate(
|
|
|
|
|
|
data={'dataset_id': self.data.get('dataset_id'),
|
|
|
|
|
|
'problem_id': problem.id}).one(with_valid=True)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_params_api():
|
|
|
|
|
|
return [openapi.Parameter(name='dataset_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='知识库id'),
|
|
|
|
|
|
openapi.Parameter(name='document_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='文档id'),
|
|
|
|
|
|
openapi.Parameter(name='paragraph_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='段落id')]
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_body_api():
|
|
|
|
|
|
return openapi.Schema(type=openapi.TYPE_OBJECT,
|
|
|
|
|
|
required=["content"],
|
|
|
|
|
|
properties={
|
|
|
|
|
|
'content': openapi.Schema(
|
|
|
|
|
|
type=openapi.TYPE_STRING, title="内容")
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_response_body_api():
|
|
|
|
|
|
return openapi.Schema(
|
|
|
|
|
|
type=openapi.TYPE_OBJECT,
|
|
|
|
|
|
required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
|
|
|
|
|
|
properties={
|
|
|
|
|
|
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
|
|
|
|
|
description="id", default="xx"),
|
|
|
|
|
|
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
|
|
|
|
|
|
description="问题内容", default='问题内容'),
|
|
|
|
|
|
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
|
|
|
|
|
default=1),
|
|
|
|
|
|
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
|
|
|
|
|
|
description="知识库id", default='xxx'),
|
|
|
|
|
|
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
|
|
|
|
|
description="修改时间",
|
|
|
|
|
|
default="1970-01-01 00:00:00"),
|
|
|
|
|
|
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
|
|
|
|
|
description="创建时间",
|
|
|
|
|
|
default="1970-01-01 00:00:00"
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class Association(ApiMixin, serializers.Serializer):
|
|
|
|
|
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
|
|
|
|
|
|
|
|
|
|
|
problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
|
|
|
|
|
|
|
|
|
|
|
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
|
|
|
|
|
|
|
|
|
|
|
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
|
|
|
|
|
|
|
|
|
|
|
def is_valid(self, *, raise_exception=True):
|
|
|
|
|
|
super().is_valid(raise_exception=True)
|
|
|
|
|
|
dataset_id = self.data.get('dataset_id')
|
|
|
|
|
|
paragraph_id = self.data.get('paragraph_id')
|
|
|
|
|
|
problem_id = self.data.get("problem_id")
|
|
|
|
|
|
if not QuerySet(Paragraph).filter(dataset_id=dataset_id, id=paragraph_id).exists():
|
|
|
|
|
|
raise AppApiException(500, "段落不存在")
|
|
|
|
|
|
if not QuerySet(Problem).filter(dataset_id=dataset_id, id=problem_id).exists():
|
|
|
|
|
|
raise AppApiException(500, "问题不存在")
|
|
|
|
|
|
|
|
|
|
|
|
def association(self, with_valid=True, with_embedding=True):
|
|
|
|
|
|
if with_valid:
|
|
|
|
|
|
self.is_valid(raise_exception=True)
|
|
|
|
|
|
problem = QuerySet(Problem).filter(id=self.data.get("problem_id"))
|
|
|
|
|
|
problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
|
|
|
|
|
|
document_id=self.data.get('document_id'),
|
|
|
|
|
|
paragraph_id=self.data.get('paragraph_id'),
|
|
|
|
|
|
dataset_id=self.data.get('dataset_id'),
|
|
|
|
|
|
problem_id=problem.id)
|
|
|
|
|
|
problem_paragraph_mapping.save()
|
|
|
|
|
|
if with_embedding:
|
|
|
|
|
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
|
|
|
|
|
'is_active': True,
|
|
|
|
|
|
'source_type': SourceType.PROBLEM,
|
|
|
|
|
|
'source_id': problem_paragraph_mapping.id,
|
|
|
|
|
|
'document_id': self.data.get('document_id'),
|
|
|
|
|
|
'paragraph_id': self.data.get('paragraph_id'),
|
|
|
|
|
|
'dataset_id': self.data.get('dataset_id'),
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
def un_association(self, with_valid=True):
|
|
|
|
|
|
if with_valid:
|
|
|
|
|
|
self.is_valid(raise_exception=True)
|
|
|
|
|
|
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
|
|
|
|
|
|
paragraph_id=self.data.get('paragraph_id'),
|
|
|
|
|
|
dataset_id=self.data.get('dataset_id'),
|
|
|
|
|
|
problem_id=self.data.get(
|
|
|
|
|
|
'problem_id')).first()
|
|
|
|
|
|
problem_paragraph_mapping_id = problem_paragraph_mapping.id
|
|
|
|
|
|
problem_paragraph_mapping.delete()
|
|
|
|
|
|
ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id)
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_params_api():
|
|
|
|
|
|
return [openapi.Parameter(name='dataset_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='知识库id'),
|
|
|
|
|
|
openapi.Parameter(name='document_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='文档id')
|
|
|
|
|
|
, openapi.Parameter(name='paragraph_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='段落id'),
|
|
|
|
|
|
openapi.Parameter(name='problem_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='问题id')
|
|
|
|
|
|
]
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
class Operate(ApiMixin, serializers.Serializer):
|
|
|
|
|
|
# 段落id
|
2024-03-04 02:12:18 +00:00
|
|
|
|
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
|
|
|
|
|
"段落id"))
|
2023-12-18 03:32:29 +00:00
|
|
|
|
# 知识库id
|
2024-03-04 02:12:18 +00:00
|
|
|
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
|
|
|
|
|
"知识库id"))
|
|
|
|
|
|
# 文档id
|
|
|
|
|
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
|
|
|
|
|
"文档id"))
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
|
|
|
|
|
def is_valid(self, *, raise_exception=True):
|
|
|
|
|
|
super().is_valid(raise_exception=True)
|
|
|
|
|
|
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
|
|
|
|
|
|
raise AppApiException(500, "段落id不存在")
|
|
|
|
|
|
|
2023-12-21 08:55:11 +00:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def post_embedding(paragraph, instance):
|
|
|
|
|
|
if 'is_active' in instance and instance.get('is_active') is not None:
|
|
|
|
|
|
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
|
|
|
|
|
|
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
|
|
|
|
|
|
s.send(paragraph.get('id'))
|
|
|
|
|
|
else:
|
|
|
|
|
|
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
|
|
|
|
|
|
return paragraph
|
|
|
|
|
|
|
|
|
|
|
|
@post(post_embedding)
|
2023-10-24 12:24:32 +00:00
|
|
|
|
@transaction.atomic
|
|
|
|
|
|
def edit(self, instance: Dict):
|
|
|
|
|
|
self.is_valid()
|
2024-03-06 10:24:58 +00:00
|
|
|
|
EditParagraphSerializers(data=instance).is_valid(raise_exception=True)
|
2023-10-24 12:24:32 +00:00
|
|
|
|
_paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
|
|
|
|
|
|
update_keys = ['title', 'content', 'is_active']
|
|
|
|
|
|
for update_key in update_keys:
|
|
|
|
|
|
if update_key in instance and instance.get(update_key) is not None:
|
|
|
|
|
|
_paragraph.__setattr__(update_key, instance.get(update_key))
|
|
|
|
|
|
|
|
|
|
|
|
if 'problem_list' in instance:
|
|
|
|
|
|
update_problem_list = list(
|
|
|
|
|
|
filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list')))
|
|
|
|
|
|
|
|
|
|
|
|
create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list')))
|
|
|
|
|
|
|
|
|
|
|
|
# 问题集合
|
|
|
|
|
|
problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
|
|
|
|
|
|
|
|
|
|
|
|
# 校验前端 携带过来的id
|
|
|
|
|
|
for update_problem in update_problem_list:
|
|
|
|
|
|
if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')):
|
|
|
|
|
|
raise AppApiException(500, update_problem.get('id') + '问题id不存在')
|
|
|
|
|
|
# 对比需要删除的问题
|
|
|
|
|
|
delete_problem_list = list(filter(
|
|
|
|
|
|
lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__(
|
|
|
|
|
|
str(row.id)), problem_list)) if len(update_problem_list) > 0 else []
|
|
|
|
|
|
# 删除问题
|
|
|
|
|
|
QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len(
|
|
|
|
|
|
delete_problem_list) > 0 else None
|
|
|
|
|
|
# 插入新的问题
|
|
|
|
|
|
QuerySet(Problem).bulk_create(
|
|
|
|
|
|
[Problem(id=uuid.uuid1(), content=p.get('content'), paragraph_id=self.data.get('paragraph_id'),
|
|
|
|
|
|
dataset_id=self.data.get('dataset_id'), document_id=self.data.get('document_id')) for
|
|
|
|
|
|
p in create_problem_list]) if len(create_problem_list) else None
|
|
|
|
|
|
|
|
|
|
|
|
# 修改问题集合
|
|
|
|
|
|
QuerySet(Problem).bulk_update(
|
|
|
|
|
|
[Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
|
|
|
|
|
|
['content']) if len(
|
|
|
|
|
|
update_problem_list) > 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
_paragraph.save()
|
2023-11-17 09:43:35 +00:00
|
|
|
|
update_document_char_length(self.data.get('document_id'))
|
2023-12-21 08:55:11 +00:00
|
|
|
|
return self.one(), instance
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
|
|
|
|
|
def get_problem_list(self):
|
2024-03-11 09:28:05 +00:00
|
|
|
|
ProblemParagraphMapping(ProblemParagraphMapping)
|
|
|
|
|
|
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
|
|
|
|
|
|
paragraph_id=self.data.get("paragraph_id"))
|
|
|
|
|
|
if len(problem_paragraph_mapping) > 0:
|
|
|
|
|
|
return [ProblemSerializer(problem).data for problem in
|
|
|
|
|
|
QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])]
|
|
|
|
|
|
return []
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
|
|
|
|
|
def one(self, with_valid=False):
|
|
|
|
|
|
if with_valid:
|
|
|
|
|
|
self.is_valid(raise_exception=True)
|
|
|
|
|
|
return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
|
|
|
|
|
|
'problem_list': self.get_problem_list()}
|
|
|
|
|
|
|
|
|
|
|
|
def delete(self, with_valid=False):
|
|
|
|
|
|
if with_valid:
|
|
|
|
|
|
self.is_valid(raise_exception=True)
|
|
|
|
|
|
paragraph_id = self.data.get('paragraph_id')
|
|
|
|
|
|
QuerySet(Paragraph).filter(id=paragraph_id).delete()
|
2024-03-11 09:28:05 +00:00
|
|
|
|
QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
|
2023-10-24 12:24:32 +00:00
|
|
|
|
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_body_api():
|
|
|
|
|
|
return ParagraphInstanceSerializer.get_request_body_api()
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_response_body_api():
|
|
|
|
|
|
return ParagraphInstanceSerializer.get_request_body_api()
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_params_api():
|
|
|
|
|
|
return [openapi.Parameter(type=openapi.TYPE_STRING, in_=openapi.IN_PATH, name='paragraph_id',
|
|
|
|
|
|
description="段落id")]
|
|
|
|
|
|
|
|
|
|
|
|
class Create(ApiMixin, serializers.Serializer):
|
2024-03-04 02:12:18 +00:00
|
|
|
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
|
|
|
|
|
"知识库id"))
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
2024-03-04 02:12:18 +00:00
|
|
|
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
|
|
|
|
|
"文档id"))
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
def is_valid(self, *, raise_exception=False):
|
|
|
|
|
|
super().is_valid(raise_exception=True)
|
|
|
|
|
|
if not QuerySet(Document).filter(id=self.data.get('document_id'),
|
|
|
|
|
|
dataset_id=self.data.get('dataset_id')).exists():
|
|
|
|
|
|
raise AppApiException(500, "文档id不正确")
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
|
|
|
|
|
if with_valid:
|
|
|
|
|
|
ParagraphSerializers(data=instance).is_valid(raise_exception=True)
|
|
|
|
|
|
self.is_valid()
|
|
|
|
|
|
dataset_id = self.data.get("dataset_id")
|
|
|
|
|
|
document_id = self.data.get('document_id')
|
2023-12-12 07:44:21 +00:00
|
|
|
|
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
|
|
|
|
|
paragraph = paragraph_problem_model.get('paragraph')
|
|
|
|
|
|
problem_model_list = paragraph_problem_model.get('problem_model_list')
|
2024-03-11 09:28:05 +00:00
|
|
|
|
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
|
2023-10-24 12:24:32 +00:00
|
|
|
|
# 插入段落
|
2023-12-12 07:44:21 +00:00
|
|
|
|
paragraph_problem_model.get('paragraph').save()
|
2023-10-24 12:24:32 +00:00
|
|
|
|
# 插入問題
|
|
|
|
|
|
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
2024-03-11 09:28:05 +00:00
|
|
|
|
# 插入问题关联关系
|
|
|
|
|
|
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
|
|
|
|
|
problem_paragraph_mapping_list) > 0 else None
|
2023-11-17 09:43:35 +00:00
|
|
|
|
# 修改长度
|
|
|
|
|
|
update_document_char_length(document_id)
|
2023-10-24 12:24:32 +00:00
|
|
|
|
if with_embedding:
|
|
|
|
|
|
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
|
|
|
|
|
|
return ParagraphSerializers.Operate(
|
|
|
|
|
|
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
|
|
|
|
|
with_valid=True)
|
|
|
|
|
|
|
2023-12-12 07:44:21 +00:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dict):
|
|
|
|
|
|
paragraph = Paragraph(id=uuid.uuid1(),
|
|
|
|
|
|
document_id=document_id,
|
|
|
|
|
|
content=instance.get("content"),
|
|
|
|
|
|
dataset_id=dataset_id,
|
|
|
|
|
|
title=instance.get("title") if 'title' in instance else '')
|
2024-03-11 09:28:05 +00:00
|
|
|
|
problem_list = instance.get('problem_list')
|
|
|
|
|
|
exists_problem_list = []
|
|
|
|
|
|
if 'problem_list' in instance and len(problem_list) > 0:
|
|
|
|
|
|
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
|
|
|
|
|
|
content__in=[p.get('content') for p in
|
|
|
|
|
|
problem_list]).all()
|
|
|
|
|
|
|
|
|
|
|
|
problem_model_list = [
|
|
|
|
|
|
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
|
|
|
|
|
|
problem in (
|
|
|
|
|
|
instance.get('problem_list') if 'problem_list' in instance else [])]
|
|
|
|
|
|
|
|
|
|
|
|
problem_paragraph_mapping_list = [
|
|
|
|
|
|
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
|
|
|
|
|
|
paragraph_id=paragraph.id,
|
|
|
|
|
|
dataset_id=dataset_id) for
|
|
|
|
|
|
problem_model in problem_model_list]
|
|
|
|
|
|
return {'paragraph': paragraph,
|
|
|
|
|
|
'problem_model_list': [problem_model for problem_model in problem_model_list if
|
|
|
|
|
|
not list(exists_problem_list).__contains__(problem_model)],
|
|
|
|
|
|
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
|
2023-12-12 07:44:21 +00:00
|
|
|
|
|
2024-03-11 09:28:05 +00:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def or_get(exists_problem_list, content, dataset_id):
|
|
|
|
|
|
exists = [row for row in exists_problem_list if row.content == content]
|
|
|
|
|
|
if len(exists) > 0:
|
|
|
|
|
|
return exists[0]
|
|
|
|
|
|
else:
|
|
|
|
|
|
return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
|
2023-12-12 07:44:21 +00:00
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_body_api():
|
|
|
|
|
|
return ParagraphInstanceSerializer.get_request_body_api()
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_params_api():
|
|
|
|
|
|
return [openapi.Parameter(name='dataset_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
2023-12-18 03:32:29 +00:00
|
|
|
|
description='知识库id'),
|
2023-10-24 12:24:32 +00:00
|
|
|
|
openapi.Parameter(name='document_id', in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description="文档id")
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
class Query(ApiMixin, serializers.Serializer):
|
2024-03-04 02:12:18 +00:00
|
|
|
|
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
|
|
|
|
|
"知识库id"))
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
2024-03-04 02:12:18 +00:00
|
|
|
|
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
|
|
|
|
|
"文档id"))
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
2024-03-04 02:12:18 +00:00
|
|
|
|
title = serializers.CharField(required=False, error_messages=ErrMessage.char(
|
|
|
|
|
|
"段落标题"))
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
2023-11-17 09:43:35 +00:00
|
|
|
|
content = serializers.CharField(required=False)
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
def get_query_set(self):
|
|
|
|
|
|
query_set = QuerySet(model=Paragraph)
|
|
|
|
|
|
query_set = query_set.filter(
|
|
|
|
|
|
**{'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get("document_id")})
|
|
|
|
|
|
if 'title' in self.data:
|
|
|
|
|
|
query_set = query_set.filter(
|
|
|
|
|
|
**{'title__contains': self.data.get('title')})
|
2023-11-17 09:43:35 +00:00
|
|
|
|
if 'content' in self.data:
|
|
|
|
|
|
query_set = query_set.filter(**{'content__contains': self.data.get('content')})
|
2023-10-24 12:24:32 +00:00
|
|
|
|
return query_set
|
|
|
|
|
|
|
|
|
|
|
|
def list(self):
|
|
|
|
|
|
return list(map(lambda row: ParagraphSerializer(row).data, self.get_query_set()))
|
|
|
|
|
|
|
|
|
|
|
|
def page(self, current_page, page_size):
|
|
|
|
|
|
query_set = self.get_query_set()
|
|
|
|
|
|
return page_search(current_page, page_size, query_set, lambda row: ParagraphSerializer(row).data)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_request_params_api():
|
|
|
|
|
|
return [openapi.Parameter(name='document_id',
|
|
|
|
|
|
in_=openapi.IN_PATH,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=True,
|
|
|
|
|
|
description='文档id'),
|
|
|
|
|
|
openapi.Parameter(name='title',
|
|
|
|
|
|
in_=openapi.IN_QUERY,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=False,
|
2023-11-17 09:43:35 +00:00
|
|
|
|
description='标题'),
|
|
|
|
|
|
openapi.Parameter(name='content',
|
|
|
|
|
|
in_=openapi.IN_QUERY,
|
|
|
|
|
|
type=openapi.TYPE_STRING,
|
|
|
|
|
|
required=False,
|
|
|
|
|
|
description='内容')
|
2023-10-24 12:24:32 +00:00
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_response_body_api():
|
|
|
|
|
|
return openapi.Schema(
|
|
|
|
|
|
type=openapi.TYPE_OBJECT,
|
2023-12-13 09:14:43 +00:00
|
|
|
|
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
|
|
|
|
|
|
'document_id', 'title',
|
2023-10-24 12:24:32 +00:00
|
|
|
|
'create_time', 'update_time'],
|
|
|
|
|
|
properties={
|
|
|
|
|
|
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
|
|
|
|
|
description="id", default="xx"),
|
|
|
|
|
|
'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
|
|
|
|
|
|
description="段落内容", default='段落内容'),
|
|
|
|
|
|
'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
|
|
|
|
|
|
description="标题", default="xxx的描述"),
|
|
|
|
|
|
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
|
|
|
|
|
default=1),
|
|
|
|
|
|
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
|
|
|
|
|
|
description="点赞数量", default=1),
|
|
|
|
|
|
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
|
|
|
|
|
|
description="点踩数", default=1),
|
2023-12-18 03:32:29 +00:00
|
|
|
|
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
|
|
|
|
|
|
description="知识库id", default='xxx'),
|
2023-10-24 12:24:32 +00:00
|
|
|
|
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
|
|
|
|
|
|
description="文档id", default='xxx'),
|
|
|
|
|
|
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
|
|
|
|
|
|
description="是否可用", default=True),
|
|
|
|
|
|
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
|
|
|
|
|
description="修改时间",
|
|
|
|
|
|
default="1970-01-01 00:00:00"),
|
|
|
|
|
|
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
|
|
|
|
|
description="创建时间",
|
|
|
|
|
|
default="1970-01-01 00:00:00"
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
)
|