2023-10-24 12:24:32 +00:00
|
|
|
|
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: problem_serializers.py
    @date:2023/10/23 13:55
    @desc: Serializers for dataset problems and their paragraph associations.
"""
|
2024-03-11 09:28:05 +00:00
|
|
|
|
import os
|
2023-10-24 12:24:32 +00:00
|
|
|
|
import uuid
|
2024-09-20 09:07:52 +00:00
|
|
|
|
from functools import reduce
|
2024-04-09 03:33:28 +00:00
|
|
|
|
from typing import Dict, List
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
2024-01-23 07:52:15 +00:00
|
|
|
|
from django.db import transaction
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from django.db.models import QuerySet
|
|
|
|
|
|
from drf_yasg import openapi
|
|
|
|
|
|
from rest_framework import serializers
|
|
|
|
|
|
|
2024-03-11 09:28:05 +00:00
|
|
|
|
from common.db.search import native_search, native_page_search
|
2023-10-24 12:24:32 +00:00
|
|
|
|
from common.mixins.api_mixin import ApiMixin
|
2024-03-04 02:12:18 +00:00
|
|
|
|
from common.util.field_message import ErrMessage
|
2024-03-11 09:28:05 +00:00
|
|
|
|
from common.util.file_util import get_file_content
|
2024-07-18 07:44:48 +00:00
|
|
|
|
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
|
2024-08-21 06:46:11 +00:00
|
|
|
|
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
|
2024-09-20 09:07:52 +00:00
|
|
|
|
from embedding.models import SourceType
|
|
|
|
|
|
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
|
2024-03-11 09:28:05 +00:00
|
|
|
|
from smartdoc.conf import PROJECT_DIR
|
2023-10-24 12:24:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProblemSerializer(serializers.ModelSerializer):
    """Model serializer that exposes the listed columns of a ``Problem`` row."""

    class Meta:
        model = Problem
        # Columns included in the serialized representation.
        fields = [
            'id',
            'content',
            'dataset_id',
            'create_time',
            'update_time',
        ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
    """Serializer for a single problem payload: optional ``id`` plus ``content``."""

    # Optional problem id.
    id = serializers.CharField(required=False, error_messages=ErrMessage.char("问题id"))

    # Problem text, mandatory and limited to 256 characters.
    content = serializers.CharField(required=True, max_length=256, error_messages=ErrMessage.char("问题内容"))

    @staticmethod
    def get_request_body_api():
        """Return the OpenAPI schema describing the request body.

        The schema is an object with a required string ``content`` and an
        optional string ``id`` (its title says: pass it when modifying, omit
        it when creating).
        """
        id_schema = openapi.Schema(
            type=openapi.TYPE_STRING,
            title="问题id,修改的时候传递,创建的时候不传")
        content_schema = openapi.Schema(
            type=openapi.TYPE_STRING, title="内容")
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=["content"],
            properties={
                'id': id_schema,
                'content': content_schema,
            })
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-09-20 09:07:52 +00:00
|
|
|
|
class AssociationParagraph(serializers.Serializer):
    """One association target: a paragraph id together with its document id."""

    # Both identifiers are mandatory UUIDs.
    paragraph_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.uuid("段落id"))
    document_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.uuid("文档id"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchAssociation(serializers.Serializer):
    """Request body for a batch association: problem ids plus target paragraphs."""

    # Mandatory list of problem UUIDs.
    problem_id_list = serializers.ListField(
        required=True,
        error_messages=ErrMessage.list("问题id列表"),
        child=serializers.UUIDField(
            required=True,
            error_messages=ErrMessage.uuid("问题id")))

    # Paragraphs the problems are attached to; each item is validated
    # by AssociationParagraph.
    paragraph_list = AssociationParagraph(many=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
    """Tell whether an equivalent problem/paragraph mapping already exists.

    Two mappings are considered equivalent when their ``paragraph_id``,
    ``problem_id`` and ``dataset_id`` are all equal. The function name keeps
    its historical spelling so existing callers continue to work.

    :param exits_problem_paragraph_mapping_list: iterable of existing mappings;
        each item must expose ``paragraph_id``, ``problem_id`` and ``dataset_id``.
    :param new_paragraph_mapping: candidate mapping exposing the same attributes.
    :return: ``True`` if at least one existing mapping matches, else ``False``.
    """
    # Normalise BOTH sides to ``str``: the ids may be ``uuid.UUID`` objects or
    # plain strings depending on where the object came from, and a UUID never
    # compares equal to its string form.
    wanted = (str(new_paragraph_mapping.paragraph_id),
              str(new_paragraph_mapping.problem_id),
              str(new_paragraph_mapping.dataset_id))
    return any(
        (str(existing.paragraph_id), str(existing.problem_id), str(existing.dataset_id)) == wanted
        for existing in exits_problem_paragraph_mapping_list
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, dataset_id: str):
    """Build a ``ProblemParagraphMapping`` instance for ``problem``.

    The mapping receives a fresh ``uuid1`` id and the stringified id of
    ``problem``; this function does not save it.

    :param problem: object exposing an ``id`` attribute.
    :param document_id: id of the document.
    :param paragraph_id: id of the paragraph.
    :param dataset_id: id of the dataset.
    :return: tuple ``(mapping, problem)``.
    """
    mapping = ProblemParagraphMapping(
        id=uuid.uuid1(),
        document_id=document_id,
        paragraph_id=paragraph_id,
        dataset_id=dataset_id,
        problem_id=str(problem.id),
    )
    return mapping, problem
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-10-24 12:24:32 +00:00
|
|
|
|
class ProblemSerializers(ApiMixin, serializers.Serializer):
    """Namespace class grouping the serializers that operate on problems.

    Each nested serializer validates the identifiers it is constructed with and
    exposes the operations available for them: ``Create`` (bulk insert),
    ``Query`` (list / page), ``BatchOperate`` (bulk delete / association) and
    ``Operate`` (single-problem read, edit, delete).
    """

    class Create(serializers.Serializer):
        """Bulk-create problems inside one dataset."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
        problem_list = serializers.ListField(required=True, error_messages=ErrMessage.list("问题列表"),
                                             child=serializers.CharField(required=True,
                                                                         max_length=256,
                                                                         error_messages=ErrMessage.char("问题")))

        def batch(self, with_valid=True):
            """Insert every problem of ``problem_list`` not yet present in the dataset.

            :param with_valid: validate the serializer first (raising on error).
            :return: serialized data (``ProblemSerializer``) of the newly created
                problems only; already existing contents are skipped.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            # dict.fromkeys removes duplicates while keeping the caller's order,
            # so the result is deterministic (a plain set() is not).
            problem_list = list(dict.fromkeys(self.data.get('problem_list')))
            existing_contents = {
                problem.content
                for problem in QuerySet(Problem).filter(dataset_id=dataset_id, content__in=problem_list)
            }
            problem_instance_list = [
                Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=problem_content)
                for problem_content in problem_list
                if problem_content not in existing_contents
            ]
            if problem_instance_list:
                QuerySet(Problem).bulk_create(problem_instance_list)
            return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]

    class Query(serializers.Serializer):
        """List or page the problems of a dataset, optionally filtered by content."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
        content = serializers.CharField(required=False, error_messages=ErrMessage.char("问题"))

        @staticmethod
        def _list_problem_sql():
            """Load the SQL text shared by ``list`` and ``page``."""
            return get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql'))

        def get_query_set(self):
            """Build the ``Problem`` queryset for the validated filters.

            Filters on ``dataset_id``, adds a case-insensitive containment
            filter when ``content`` was supplied, and orders newest first.
            """
            query_set = QuerySet(model=Problem).filter(dataset_id=self.data.get('dataset_id'))
            if 'content' in self.data:
                query_set = query_set.filter(content__icontains=self.data.get('content'))
            return query_set.order_by("-create_time")

        def list(self):
            """Run ``list_problem.sql`` through ``native_search`` for the filtered queryset."""
            query_set = self.get_query_set()
            return native_search(query_set, select_string=self._list_problem_sql())

        def page(self, current_page, page_size):
            """Same as ``list`` but delegated to ``native_page_search``.

            :param current_page: page number forwarded to ``native_page_search``.
            :param page_size: page size forwarded to ``native_page_search``.
            """
            query_set = self.get_query_set()
            return native_page_search(current_page, page_size, query_set,
                                      select_string=self._list_problem_sql())

    class BatchOperate(serializers.Serializer):
        """Operations applied to several problems of one dataset at once."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

        @transaction.atomic
        def delete(self, problem_id_list: List, with_valid=True):
            """Delete the given problems, their paragraph mappings and embeddings.

            Runs in a transaction, like the single-problem ``Operate.delete``,
            so the mapping and problem rows are removed together or not at all.

            :param problem_id_list: ids of the problems to delete.
            :param with_valid: validate the serializer first (raising on error).
            :return: ``True``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
                dataset_id=dataset_id,
                problem_id__in=problem_id_list)
            # Collect the ids BEFORE deleting: afterwards the rows are gone.
            source_ids = [row.id for row in problem_paragraph_mapping_list]
            problem_paragraph_mapping_list.delete()
            # NOTE(review): problems are matched by id only, not by dataset_id.
            QuerySet(Problem).filter(id__in=problem_id_list).delete()
            delete_embedding_by_source_ids(source_ids)
            return True

        def association(self, instance: Dict, with_valid=True):
            """Associate every listed problem with every listed paragraph.

            Mappings that already exist are skipped; the new ones are
            bulk-created and handed to ``embedding_by_data_list``.

            :param instance: request body with ``problem_id_list`` and
                ``paragraph_list`` (items carrying ``paragraph_id`` and
                ``document_id``), see ``BatchAssociation``.
            :param with_valid: validate this serializer and ``instance`` first.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                BatchAssociation(data=instance).is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            paragraph_list = instance.get('paragraph_list')
            problem_id_list = instance.get('problem_id_list')
            problem_list = QuerySet(Problem).filter(id__in=problem_id_list)
            # Evaluate once; the result is scanned for every candidate below.
            exits_problem_paragraph_mapping = list(
                QuerySet(ProblemParagraphMapping).filter(
                    problem_id__in=problem_id_list,
                    paragraph_id__in=[p.get('paragraph_id') for p in paragraph_list]))
            # Cartesian product problems x paragraphs, problems outermost.
            candidates = [
                to_problem_paragraph_mapping(problem,
                                             paragraph.get('document_id'),
                                             paragraph.get('paragraph_id'),
                                             dataset_id)
                for problem in problem_list
                for paragraph in paragraph_list
            ]
            problem_paragraph_mapping_list = [
                (problem_paragraph_mapping, problem)
                for problem_paragraph_mapping, problem in candidates
                if not is_exits(exits_problem_paragraph_mapping, problem_paragraph_mapping)
            ]
            QuerySet(ProblemParagraphMapping).bulk_create(
                [problem_paragraph_mapping for problem_paragraph_mapping, _ in problem_paragraph_mapping_list])
            data_list = [{'text': problem.content,
                          'is_active': True,
                          'source_type': SourceType.PROBLEM,
                          'source_id': str(problem_paragraph_mapping.id),
                          'document_id': str(problem_paragraph_mapping.document_id),
                          'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
                          'dataset_id': dataset_id,
                          } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list]
            model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id'))
            embedding_by_data_list(data_list, model_id=model_id)

    class Operate(serializers.Serializer):
        """Operations on a single problem identified by dataset id and problem id."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

        problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))

        def list_paragraph(self, with_valid=True):
            """Return the paragraphs mapped to this problem.

            :param with_valid: validate the serializer first (raising on error).
            :return: ``[]`` when the problem has no mapping, otherwise the
                result of ``native_search`` with ``list_paragraph.sql``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
                dataset_id=self.data.get("dataset_id"),
                problem_id=self.data.get("problem_id"))
            # A QuerySet is never None; truthiness evaluates it and caches rows.
            if not problem_paragraph_mapping:
                return []
            return native_search(
                QuerySet(Paragraph).filter(id__in=[row.paragraph_id for row in problem_paragraph_mapping]),
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))

        def one(self, with_valid=True):
            """Return the problem serialized with ``ProblemInstanceSerializer``.

            :param with_valid: validate the serializer first (raising on error).
            :raises Problem.DoesNotExist: when no problem has this id
                (behaviour of ``QuerySet.get``).
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem = QuerySet(Problem).get(id=self.data.get('problem_id'))
            return ProblemInstanceSerializer(problem).data

        @transaction.atomic
        def delete(self, with_valid=True):
            """Delete the problem, its paragraph mappings and their embeddings.

            :param with_valid: validate the serializer first (raising on error).
            :return: ``True``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem_id = self.data.get('problem_id')
            problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
                dataset_id=self.data.get('dataset_id'),
                problem_id=problem_id)
            # Collect the ids BEFORE deleting: afterwards the rows are gone.
            source_ids = [row.id for row in problem_paragraph_mapping_list]
            problem_paragraph_mapping_list.delete()
            QuerySet(Problem).filter(id=problem_id).delete()
            delete_embedding_by_source_ids(source_ids)
            return True

        @transaction.atomic
        def edit(self, instance: Dict, with_valid=True):
            """Replace the problem's content and refresh its embedding.

            :param instance: request body; ``content`` holds the new text.
            :param with_valid: validate this serializer and ``instance`` first.
            :raises serializers.ValidationError: when validation fails or the
                problem does not exist in the dataset.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                # Apply the same content rules (required, max 256 characters)
                # as on creation instead of failing later at the database.
                ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
            problem_id = self.data.get('problem_id')
            dataset_id = self.data.get('dataset_id')
            content = instance.get('content')
            problem = QuerySet(Problem).filter(id=problem_id,
                                               dataset_id=dataset_id).first()
            # first() returns None when nothing matches; fail with a clear
            # validation error rather than an AttributeError.
            if problem is None:
                raise serializers.ValidationError("问题不存在")
            problem.content = content
            problem.save()
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            update_problem_embedding(problem_id, content, model_id)
|