UnisKB/apps/dataset/serializers/paragraph_serializers.py

740 lines
40 KiB
Python
Raw Normal View History

# coding=utf-8
"""
@project: maxkb
@Author:
@file: paragraph_serializers.py
@date: 2023/10/16 15:51
@desc: Serializers for dataset paragraphs and the problems associated with them.
"""
import uuid
from typing import Dict
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import page_search
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
2024-03-04 02:12:18 +00:00
from common.util.field_message import ErrMessage
2024-07-18 07:44:48 +00:00
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
2024-08-21 06:46:11 +00:00
ProblemParagraphManage, get_embedding_model_id_by_dataset_id
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
2024-08-21 06:46:11 +00:00
from embedding.task.embedding import embedding_by_problem as embedding_by_problem_task, embedding_by_problem, \
delete_embedding_by_source, enable_embedding_by_paragraph, disable_embedding_by_paragraph, embedding_by_paragraph, \
delete_embedding_by_paragraph, delete_embedding_by_paragraph_ids, update_embedding_document_id
from dataset.task import generate_related_by_paragraph_id_list
class ParagraphSerializer(serializers.ModelSerializer):
    """Model serializer for ``Paragraph`` restricted to the fields listed in ``Meta``."""

    class Meta:
        model = Paragraph
        fields = [
            'id',
            'content',
            'is_active',
            'document_id',
            'title',
            'create_time',
            'update_time',
        ]
class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Paragraph instance object.

    Describes one paragraph payload: its content, an optional title, an
    optional list of problems and an optional active flag.
    """
    # Paragraph text; the request-body schema below must advertise the same upper bound.
    content = serializers.CharField(required=True, error_messages=ErrMessage.char("段落内容"),
                                    max_length=102400,
                                    min_length=1,
                                    allow_null=True, allow_blank=True)
    # Optional paragraph title.
    title = serializers.CharField(required=False, max_length=256, error_messages=ErrMessage.char("段落标题"),
                                  allow_null=True, allow_blank=True)
    # Optional problems supplied together with the paragraph.
    problem_list = ProblemInstanceSerializer(required=False, many=True)
    # Optional flag telling whether the paragraph is usable.
    is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char("段落是否可用"))

    @staticmethod
    def get_request_body_api():
        """
        Build the OpenAPI schema of the paragraph request body.

        :return: ``openapi.Schema`` describing content, title, is_active and problem_list
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['content'],
            properties={
                # Was 4096, which contradicted the 102400 limit enforced by the ``content`` field above.
                'content': openapi.Schema(type=openapi.TYPE_STRING, max_length=102400, title="分段内容",
                                          description="分段内容"),
                'title': openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题",
                                        description="分段标题"),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                'problem_list': openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
                                               description="问题列表",
                                               items=ProblemInstanceSerializer.get_request_body_api())
            }
        )
class EditParagraphSerializers(serializers.Serializer):
    """Validator for a paragraph edit payload; every field is optional."""

    title = serializers.CharField(
        required=False,
        max_length=256,
        error_messages=ErrMessage.char("分段标题"),
        allow_null=True,
        allow_blank=True,
    )
    content = serializers.CharField(
        required=False,
        max_length=102400,
        allow_null=True,
        allow_blank=True,
        error_messages=ErrMessage.char("分段内容"),
    )
    problem_list = ProblemInstanceSerializer(required=False, many=True)
class ParagraphSerializers(ApiMixin, serializers.Serializer):
    """
    Paragraph payload validator: an optional title plus mandatory content.

    The class is also used as a namespace for paragraph operations elsewhere in
    this module (for example ``ParagraphSerializers.Operate``).
    """
    title = serializers.CharField(
        required=False,
        max_length=256,
        error_messages=ErrMessage.char("分段标题"),
        allow_null=True,
        allow_blank=True,
    )
    content = serializers.CharField(
        required=True,
        max_length=102400,
        error_messages=ErrMessage.char("分段内容"),
    )
class Problem(ApiMixin, serializers.Serializer):
    """
    Problems attached to one paragraph: list them, or create one and link it.

    NOTE(review): inside the methods ``Problem`` is used as the model imported
    from ``dataset.models`` (it is filtered by ``dataset_id``/``content``), so
    this serializer presumably lives nested in ``ParagraphSerializers`` like its
    sibling ``Operate``; a module-level class with this name would shadow the model.
    """
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the ids and make sure the paragraph exists.

        Field errors are always raised; ``raise_exception`` is accepted only to
        keep the serializer signature.

        :param raise_exception: ignored, validation always raises on failure
        :return: True when validation succeeds
        :raises AppApiException: when no paragraph has the given id
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
            raise AppApiException(500, "段落id不存在")
        # Return a truthy result so ``if serializer.is_valid():`` behaves as callers expect.
        return True

    def list(self, with_valid=False):
        """
        Get the problems linked to the paragraph.

        :param with_valid: whether to validate first
        :return: list of serialized problems
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        mapping_query_set = QuerySet(ProblemParagraphMapping).filter(
            dataset_id=self.data.get("dataset_id"),
            paragraph_id=self.data.get('paragraph_id'))
        problem_id_list = [mapping.problem_id for mapping in mapping_query_set]
        return [ProblemSerializer(problem).data for problem in
                QuerySet(Problem).filter(id__in=problem_id_list)]

    @transaction.atomic
    def save(self, instance: Dict, with_valid=True, with_embedding=True, embedding_by_problem=None):
        """
        Create (or reuse) a problem with the given content and link it to the paragraph.

        :param instance: dict carrying the problem ``content``
        :param with_valid: whether to validate this serializer and ``instance`` first
        :param with_embedding: whether to trigger embedding for the new link
        :param embedding_by_problem: unused; kept so existing callers keep working
        :return: result of ``ProblemSerializers.Operate(...).one(with_valid=True)`` for the problem
        :raises AppApiException: when the problem is already linked to the paragraph
        """
        if with_valid:
            self.is_valid()
            ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
        dataset_id = self.data.get('dataset_id')
        document_id = self.data.get('document_id')
        paragraph_id = self.data.get('paragraph_id')
        content = instance.get('content')

        # Reuse an existing problem with identical content inside the dataset.
        problem = QuerySet(Problem).filter(dataset_id=dataset_id, content=content).first()
        if problem is None:
            problem = Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=content)
            problem.save()

        if QuerySet(ProblemParagraphMapping).filter(dataset_id=dataset_id, problem_id=problem.id,
                                                    paragraph_id=paragraph_id).exists():
            raise AppApiException(500, "已经关联,请勿重复关联")

        problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
                                                            problem_id=problem.id,
                                                            document_id=document_id,
                                                            paragraph_id=paragraph_id,
                                                            dataset_id=dataset_id)
        problem_paragraph_mapping.save()

        if with_embedding:
            # Resolve the model only when it is needed, matching ``Association.association``.
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            embedding_by_problem_task({'text': problem.content,
                                       'is_active': True,
                                       'source_type': SourceType.PROBLEM,
                                       'source_id': problem_paragraph_mapping.id,
                                       'document_id': document_id,
                                       'paragraph_id': paragraph_id,
                                       'dataset_id': dataset_id,
                                       }, model_id)
        return ProblemSerializers.Operate(
            data={'dataset_id': dataset_id,
                  'problem_id': problem.id}).one(with_valid=True)

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters (dataset, document and paragraph ids).

        :return: list of ``openapi.Parameter``
        """
        return [openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='知识库id'),
                openapi.Parameter(name='document_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='文档id'),
                openapi.Parameter(name='paragraph_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='段落id')]

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body: an object with a required ``content`` string.

        :return: ``openapi.Schema``
        """
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=["content"],
                              properties={
                                  'content': openapi.Schema(
                                      type=openapi.TYPE_STRING, title="内容")
                              })

    @staticmethod
    def get_response_body_api():
        """
        Describe the response body of a problem.

        :return: ``openapi.Schema``
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
                                          description="问题内容", default='问题内容'),
                'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                          default=1),
                'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
                                             description="知识库id", default='xxx'),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              )
            }
        )
class Association(ApiMixin, serializers.Serializer):
    """Link an existing problem to a paragraph, or remove such a link."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))

    def is_valid(self, *, raise_exception=True):
        """
        Validate the ids and make sure both paragraph and problem exist in the dataset.

        :param raise_exception: ignored, validation always raises on failure
        :return: True when validation succeeds
        :raises AppApiException: when the paragraph or the problem is missing
        """
        super().is_valid(raise_exception=True)
        dataset_id = self.data.get('dataset_id')
        paragraph_id = self.data.get('paragraph_id')
        problem_id = self.data.get("problem_id")
        if not QuerySet(Paragraph).filter(dataset_id=dataset_id, id=paragraph_id).exists():
            raise AppApiException(500, "段落不存在")
        if not QuerySet(Problem).filter(dataset_id=dataset_id, id=problem_id).exists():
            raise AppApiException(500, "问题不存在")
        # Return a truthy result so ``if serializer.is_valid():`` behaves as callers expect.
        return True

    def association(self, with_valid=True, with_embedding=True):
        """
        Create the problem/paragraph mapping and optionally embed the problem text.

        NOTE(review): unlike ``Problem.save`` this does not reject an already
        existing mapping — confirm whether duplicates are acceptable.

        :param with_valid: whether to validate first
        :param with_embedding: whether to trigger embedding for the new link
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        dataset_id = self.data.get('dataset_id')
        document_id = self.data.get('document_id')
        paragraph_id = self.data.get('paragraph_id')
        problem = QuerySet(Problem).filter(id=self.data.get("problem_id")).first()
        problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
                                                            document_id=document_id,
                                                            paragraph_id=paragraph_id,
                                                            dataset_id=dataset_id,
                                                            problem_id=problem.id)
        problem_paragraph_mapping.save()
        if with_embedding:
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            embedding_by_problem({'text': problem.content,
                                  'is_active': True,
                                  'source_type': SourceType.PROBLEM,
                                  'source_id': problem_paragraph_mapping.id,
                                  'document_id': document_id,
                                  'paragraph_id': paragraph_id,
                                  'dataset_id': dataset_id,
                                  }, model_id)

    def un_association(self, with_valid=True):
        """
        Delete the problem/paragraph mapping together with its embedding.

        :param with_valid: whether to validate first
        :return: True once the mapping is removed
        :raises AppApiException: when the problem is not linked to the paragraph
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
            paragraph_id=self.data.get('paragraph_id'),
            dataset_id=self.data.get('dataset_id'),
            problem_id=self.data.get('problem_id')).first()
        if problem_paragraph_mapping is None:
            # ``first()`` yields None when there is no link; fail with an API error
            # instead of an AttributeError on ``None.id``.
            raise AppApiException(500, "问题与段落未关联")
        # Keep the id: it is needed for the embedding cleanup after the row is deleted.
        problem_paragraph_mapping_id = problem_paragraph_mapping.id
        problem_paragraph_mapping.delete()
        delete_embedding_by_source(problem_paragraph_mapping_id)
        return True

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters (dataset, document, paragraph and problem ids).

        :return: list of ``openapi.Parameter``
        """
        return [openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='知识库id'),
                openapi.Parameter(name='document_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='文档id'),
                openapi.Parameter(name='paragraph_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='段落id'),
                openapi.Parameter(name='problem_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='问题id')
                ]
class Batch(serializers.Serializer):
    """Batch operations on the paragraphs of one document."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

    @transaction.atomic
    def batch_delete(self, instance: Dict, with_valid=True):
        """
        Delete the paragraphs listed in ``instance['id_list']``.

        Removes the paragraphs, their problem mappings and their embeddings and
        refreshes the character count of the document.

        NOTE(review): the delete filters by id only, not by dataset/document —
        confirm that callers restrict ``id_list`` to this document.

        :param instance: dict holding ``id_list``
        :param with_valid: whether to validate ``instance`` and this serializer first
        :return: True
        """
        if with_valid:
            BatchSerializer(data=instance).is_valid(model=Paragraph, raise_exception=True)
            self.is_valid(raise_exception=True)
        target_ids = instance.get("id_list")
        paragraph_query_set = QuerySet(Paragraph).filter(id__in=target_ids)
        paragraph_query_set.delete()
        mapping_query_set = QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=target_ids)
        mapping_query_set.delete()
        update_document_char_length(self.data.get('document_id'))
        # Remove the vectors stored for these paragraphs.
        delete_embedding_by_paragraph_ids(target_ids)
        return True
class Migrate(ApiMixin, serializers.Serializer):
    """Move paragraphs from one document to another, possibly across datasets."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    target_dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("目标知识库id"))
    target_document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("目标文档id"))
    paragraph_id_list = serializers.ListField(required=True, error_messages=ErrMessage.char("段落列表"),
                                              child=serializers.UUIDField(required=True,
                                                                          error_messages=ErrMessage.uuid("段落id")))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the ids and check that source and target documents exist and differ.

        :param raise_exception: ignored, validation always raises on failure
        :return: True when validation succeeds
        :raises AppApiException: when both documents are the same or one of them is missing
        """
        super().is_valid(raise_exception=True)
        document_id = self.data.get('document_id')
        target_document_id = self.data.get('target_document_id')
        if document_id == target_document_id:
            raise AppApiException(5000, "需要迁移的文档和目标文档一致")
        existing_document_ids = {str(document.id) for document in
                                 QuerySet(Document).filter(id__in=[document_id, target_document_id])}
        if document_id not in existing_document_ids:
            # The closing bracket was missing from these two messages.
            raise AppApiException(5000, f"文档id不存在【{document_id}】")
        if target_document_id not in existing_document_ids:
            raise AppApiException(5000, f"目标文档id不存在【{target_document_id}】")
        # Return a truthy result so ``if serializer.is_valid():`` behaves as callers expect.
        return True

    @transaction.atomic
    def migrate(self, with_valid=True):
        """
        Move the selected paragraphs to the target document.

        Within one dataset only the document reference of paragraphs, mappings
        and embeddings changes. Across datasets the linked problems are also
        re-pointed to (or created in) the target dataset.

        :param with_valid: whether to validate first
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        dataset_id = self.data.get('dataset_id')
        target_dataset_id = self.data.get('target_dataset_id')
        document_id = self.data.get('document_id')
        target_document_id = self.data.get('target_document_id')
        paragraph_id_list = self.data.get('paragraph_id_list')
        paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id=document_id,
                                                    id__in=paragraph_id_list)
        # Materialise the rows once: they are mutated in place below and the very
        # same objects must reach ``bulk_update``.
        problem_paragraph_mapping_list = list(
            QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list))
        if target_dataset_id == dataset_id:
            # Migration inside one dataset.
            if problem_paragraph_mapping_list:
                problem_paragraph_mapping_list = [
                    self.update_problem_paragraph_mapping(target_document_id, problem_paragraph_mapping)
                    for problem_paragraph_mapping in problem_paragraph_mapping_list]
                # Update the mappings.
                QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                                              ['document_id'])
            update_embedding_document_id([paragraph.id for paragraph in paragraph_list],
                                         target_document_id, target_dataset_id, None)
            # Update the paragraphs.
            paragraph_list.update(document_id=target_document_id)
        else:
            # Migration across datasets.
            problem_list = list(QuerySet(Problem).filter(
                id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in
                        problem_paragraph_mapping_list]))
            # Problems that already exist in the target dataset with the same content.
            target_problem_list = list(
                QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
                                         dataset_id=target_dataset_id))
            target_handle_problem_list = [
                self.get_target_dataset_problem(target_dataset_id, target_document_id,
                                                problem_paragraph_mapping,
                                                problem_list, target_problem_list)
                for problem_paragraph_mapping in problem_paragraph_mapping_list]
            # ``get_target_dataset_problem`` returns None when the source problem is
            # missing; skip those entries instead of failing on tuple unpacking.
            create_problem_list = [handled[0] for handled in target_handle_problem_list
                                   if handled is not None and handled[1]]
            # Insert the new problems.
            QuerySet(Problem).bulk_create(create_problem_list)
            # Update the mappings.
            QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                                          ['problem_id', 'dataset_id', 'document_id'])
            target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
            dataset = QuerySet(DataSet).filter(id=dataset_id).first()
            # A model id is passed on only when the two datasets use different embedding models.
            embedding_model_id = None
            if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
                embedding_model_id = str(target_dataset.embedding_mode_id)
            # Collect the ids before the update: afterwards the queryset filter no longer matches.
            pid_list = [paragraph.id for paragraph in paragraph_list]
            # Update the paragraphs.
            paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
            # Update the paragraph information stored with the vectors.
            update_embedding_document_id(pid_list, target_document_id, target_dataset_id, embedding_model_id)
        # Both documents changed size, whichever branch ran.
        update_document_char_length(document_id)
        update_document_char_length(target_document_id)

    @staticmethod
    def update_problem_paragraph_mapping(target_document_id: str, problem_paragraph_mapping):
        """
        Point a mapping at the target document.

        :param target_document_id: id of the document to move to
        :param problem_paragraph_mapping: mapping object, modified in place
        :return: the same mapping object
        """
        problem_paragraph_mapping.document_id = target_document_id
        return problem_paragraph_mapping

    @staticmethod
    def get_target_dataset_problem(target_dataset_id: str,
                                   target_document_id: str,
                                   problem_paragraph_mapping,
                                   source_problem_list,
                                   target_problem_list):
        """
        Re-point one mapping to the target dataset/document and resolve its problem.

        The mapping is modified in place. When the target dataset already has a
        problem with the same content it is reused; otherwise a new, unsaved
        problem is built and appended to ``target_problem_list`` so later
        mappings with the same content share it.

        :param target_dataset_id: id of the target dataset
        :param target_document_id: id of the target document
        :param problem_paragraph_mapping: mapping object, modified in place
        :param source_problem_list: problems of the source dataset
        :param target_problem_list: problems known in the target dataset, may be extended
        :return: ``(problem, is_create)``, or None when the mapping's source problem is not
                 in ``source_problem_list``
        """
        matching_source_list = [source_problem for source_problem in source_problem_list if
                                source_problem.id == problem_paragraph_mapping.problem_id]
        problem_paragraph_mapping.dataset_id = target_dataset_id
        problem_paragraph_mapping.document_id = target_document_id
        if not matching_source_list:
            return None
        problem_content = matching_source_list[-1].content
        same_content_list = [problem for problem in target_problem_list if problem.content == problem_content]
        if same_content_list:
            problem = same_content_list[-1]
            problem_paragraph_mapping.problem_id = problem.id
            return problem, False
        problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content)
        target_problem_list.append(problem)
        problem_paragraph_mapping.problem_id = problem.id
        return problem, True

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters (source/target dataset and document ids).

        :return: list of ``openapi.Parameter``
        """
        return [openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  # Was '文档id' (document id), which mislabelled the dataset id.
                                  description='知识库id'),
                openapi.Parameter(name='document_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='文档id'),
                openapi.Parameter(name='target_dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='目标知识库id'),
                openapi.Parameter(name='target_document_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  # Was '目标知识库id' (target dataset id), which mislabelled the document id.
                                  description='目标文档id')
                ]

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body: an array of paragraph id strings.

        :return: ``openapi.Schema``
        """
        return openapi.Schema(
            type=openapi.TYPE_ARRAY,
            items=openapi.Schema(type=openapi.TYPE_STRING),
            title='段落id列表',
            description="段落id列表"
        )
class Operate(ApiMixin, serializers.Serializer):
    """Operations on a single paragraph: read, edit and delete."""
    # Paragraph id
    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
    # Dataset id
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    # Document id
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

    def is_valid(self, *, raise_exception=True):
        """
        Validate the ids and make sure the paragraph exists.

        :param raise_exception: ignored, validation always raises on failure
        :return: True when validation succeeds
        :raises AppApiException: when no paragraph has the given id
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
            raise AppApiException(500, "段落id不存在")
        # Return a truthy result so ``if serializer.is_valid():`` behaves as callers expect.
        return True

    @staticmethod
    def post_embedding(paragraph, instance, dataset_id):
        """
        Embedding follow-up registered on ``edit`` through the ``post`` decorator.

        Takes the same three values ``edit`` returns — presumably ``post``
        forwards them here; confirm against ``common.util.common.post``.

        :param paragraph: serialized paragraph dict
        :param instance: the edit payload
        :param dataset_id: dataset the paragraph belongs to
        :return: ``paragraph`` unchanged
        """
        if instance.get('is_active') is not None:
            # The payload sets the active flag: only toggle the stored vectors.
            toggle_embedding = enable_embedding_by_paragraph if instance.get(
                'is_active') else disable_embedding_by_paragraph
            toggle_embedding(paragraph.get('id'))
        else:
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            embedding_by_paragraph(paragraph.get('id'), model_id)
        return paragraph

    @post(post_embedding)
    @transaction.atomic
    def edit(self, instance: Dict):
        """
        Apply an edit payload to the paragraph.

        NOTE(review): the ``problem_list`` branch filters and builds ``Problem``
        rows by ``paragraph_id``/``document_id``, while the rest of this module
        links problems through ``ProblemParagraphMapping`` — confirm that the
        model still has those fields.

        :param instance: dict that may carry title, content, is_active and problem_list
        :return: tuple ``(serialized paragraph, instance, dataset_id)``
        :raises AppApiException: when a problem id in the payload is unknown
        """
        self.is_valid()
        EditParagraphSerializers(data=instance).is_valid(raise_exception=True)
        _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
        for update_key in ('title', 'content', 'is_active'):
            # Absent or None values leave the stored value untouched.
            if instance.get(update_key) is not None:
                setattr(_paragraph, update_key, instance.get(update_key))

        if 'problem_list' in instance:
            incoming_problem_list = instance.get('problem_list')
            update_problem_list = [row for row in incoming_problem_list if row.get('id') is not None]
            create_problem_list = [row for row in incoming_problem_list if row.get('id') is None]

            # Problems currently stored for the paragraph.
            problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
            existing_id_set = {str(row.id) for row in problem_list}

            # Check the ids sent by the front end.
            for update_problem in update_problem_list:
                if update_problem.get('id') not in existing_id_set:
                    raise AppApiException(500, update_problem.get('id') + '问题id不存在')
            # Stored problems that are absent from the payload are deleted; this
            # only happens when the payload carries at least one existing id.
            keep_id_set = {str(update_row.get('id')) for update_row in update_problem_list}
            delete_problem_list = [row for row in problem_list if
                                   str(row.id) not in keep_id_set] if update_problem_list else []
            # Delete problems.
            if delete_problem_list:
                QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete()
            # Insert new problems.
            if create_problem_list:
                QuerySet(Problem).bulk_create(
                    [Problem(id=uuid.uuid1(), content=p.get('content'),
                             paragraph_id=self.data.get('paragraph_id'),
                             dataset_id=self.data.get('dataset_id'),
                             document_id=self.data.get('document_id')) for
                     p in create_problem_list])
            # Update existing problems.
            if update_problem_list:
                QuerySet(Problem).bulk_update(
                    [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
                    ['content'])

        _paragraph.save()
        update_document_char_length(self.data.get('document_id'))
        return self.one(), instance, self.data.get('dataset_id')

    def get_problem_list(self):
        """
        Get the problems linked to the paragraph through ``ProblemParagraphMapping``.

        :return: list of serialized problems, empty when nothing is linked
        """
        problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
            paragraph_id=self.data.get("paragraph_id"))
        if len(problem_paragraph_mapping) > 0:
            return [ProblemSerializer(problem).data for problem in
                    QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])]
        return []

    def one(self, with_valid=False):
        """
        Get the paragraph together with its problems.

        :param with_valid: whether to validate first
        :return: serialized paragraph dict extended with ``problem_list``
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
                'problem_list': self.get_problem_list()}

    def delete(self, with_valid=False):
        """
        Delete the paragraph, its problem mappings and its embedding.

        :param with_valid: whether to validate first
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        paragraph_id = self.data.get('paragraph_id')
        QuerySet(Paragraph).filter(id=paragraph_id).delete()
        QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
        update_document_char_length(self.data.get('document_id'))
        delete_embedding_by_paragraph(paragraph_id)

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body; same schema as ``ParagraphInstanceSerializer``.

        :return: ``openapi.Schema``
        """
        return ParagraphInstanceSerializer.get_request_body_api()

    @staticmethod
    def get_response_body_api():
        """
        Describe the response body; reuses the request-body schema of ``ParagraphInstanceSerializer``.

        :return: ``openapi.Schema``
        """
        return ParagraphInstanceSerializer.get_request_body_api()

    @staticmethod
    def get_request_params_api():
        """
        Describe the ``paragraph_id`` path parameter.

        :return: list of ``openapi.Parameter``
        """
        return [openapi.Parameter(type=openapi.TYPE_STRING, in_=openapi.IN_PATH, name='paragraph_id',
                                  description="段落id")]
class Create(ApiMixin, serializers.Serializer):
    """Create a paragraph, with optional problems, inside a document."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the ids and make sure the document belongs to the dataset.

        :param raise_exception: ignored, validation always raises on failure
        :return: True when validation succeeds
        :raises AppApiException: when the document is not found in the dataset
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Document).filter(id=self.data.get('document_id'),
                                         dataset_id=self.data.get('dataset_id')).exists():
            raise AppApiException(500, "文档id不正确")
        # Return a truthy result so ``if serializer.is_valid():`` behaves as callers expect.
        return True

    def save(self, instance: Dict, with_valid=True, with_embedding=True):
        """
        Insert the paragraph, its problems and their mappings.

        :param instance: dict carrying content and optionally title and problem_list
        :param with_valid: whether to validate ``instance`` and this serializer first
        :param with_embedding: whether to trigger embedding for the new paragraph
        :return: the created paragraph as returned by ``ParagraphSerializers.Operate(...).one``
        """
        if with_valid:
            ParagraphSerializers(data=instance).is_valid(raise_exception=True)
            self.is_valid()
        dataset_id = self.data.get("dataset_id")
        document_id = self.data.get('document_id')
        paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
        paragraph = paragraph_problem_model.get('paragraph')
        problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
        problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage(
            problem_paragraph_object_list, dataset_id).to_problem_model_list()
        # Insert the paragraph.
        paragraph.save()
        # Insert the problems.
        if len(problem_model_list) > 0:
            QuerySet(Problem).bulk_create(problem_model_list)
        # Insert the problem/paragraph mappings.
        if len(problem_paragraph_mapping_list) > 0:
            QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list)
        # Refresh the character count of the document.
        update_document_char_length(document_id)
        if with_embedding:
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            embedding_by_paragraph(str(paragraph.id), model_id)
        return ParagraphSerializers.Operate(
            data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
            with_valid=True)

    @staticmethod
    def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dict):
        """
        Build the unsaved paragraph and the problem descriptors from the payload.

        :param dataset_id: dataset the paragraph belongs to
        :param document_id: document the paragraph belongs to
        :param instance: dict carrying content and optionally title and problem_list
        :return: dict with ``paragraph`` and ``problem_paragraph_object_list``
        """
        paragraph = Paragraph(id=uuid.uuid1(),
                              document_id=document_id,
                              content=instance.get("content"),
                              dataset_id=dataset_id,
                              title=instance.get("title") if 'title' in instance else '')
        # ``or []`` also covers an explicit ``problem_list: None``, which used to
        # fail when iterated.
        problem_paragraph_object_list = [
            ProblemParagraphObject(dataset_id, document_id, paragraph.id, problem.get('content')) for problem in
            (instance.get('problem_list') or [])]
        return {'paragraph': paragraph,
                'problem_paragraph_object_list': problem_paragraph_object_list}

    @staticmethod
    def or_get(exists_problem_list, content, dataset_id):
        """
        Return the first existing problem with this content, or a new unsaved one.

        :param exists_problem_list: problems to search
        :param content: problem text to look for
        :param dataset_id: dataset used when a new problem has to be built
        :return: a ``Problem`` object
        """
        for row in exists_problem_list:
            if row.content == content:
                return row
        return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body; same schema as ``ParagraphInstanceSerializer``.

        :return: ``openapi.Schema``
        """
        return ParagraphInstanceSerializer.get_request_body_api()

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters (dataset and document ids).

        :return: list of ``openapi.Parameter``
        """
        return [openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='知识库id'),
                openapi.Parameter(name='document_id', in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description="文档id")
                ]
class Query(ApiMixin, serializers.Serializer):
    """Query the paragraphs of a document, optionally filtered by title and content."""
    # UUID fields use the uuid error messages, like the sibling serializers in this module.
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    title = serializers.CharField(required=False, error_messages=ErrMessage.char("段落标题"))
    content = serializers.CharField(required=False)

    def get_query_set(self):
        """
        Build the paragraph queryset for the current filters.

        Title and content are matched case-insensitively as substrings, and only
        when they are present in the serializer data.

        :return: queryset of ``Paragraph``
        """
        query_set = QuerySet(model=Paragraph).filter(
            dataset_id=self.data.get('dataset_id'),
            document_id=self.data.get("document_id"))
        if 'title' in self.data:
            query_set = query_set.filter(title__icontains=self.data.get('title'))
        if 'content' in self.data:
            query_set = query_set.filter(content__icontains=self.data.get('content'))
        return query_set

    def list(self):
        """
        Get all matching paragraphs.

        :return: list of serialized paragraphs
        """
        return [ParagraphSerializer(row).data for row in self.get_query_set()]

    def page(self, current_page, page_size):
        """
        Get one page of matching paragraphs.

        :param current_page: page number passed to ``page_search``
        :param page_size: page size passed to ``page_search``
        :return: result of ``page_search``
        """
        query_set = self.get_query_set()
        return page_search(current_page, page_size, query_set, lambda row: ParagraphSerializer(row).data)

    @staticmethod
    def get_request_params_api():
        """
        Describe the request parameters (document id path parameter, title/content query filters).

        :return: list of ``openapi.Parameter``
        """
        return [openapi.Parameter(name='document_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='文档id'),
                openapi.Parameter(name='title',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_STRING,
                                  required=False,
                                  description='标题'),
                openapi.Parameter(name='content',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_STRING,
                                  required=False,
                                  description='内容')
                ]

    @staticmethod
    def get_response_body_api():
        """
        Describe the response body of a paragraph.

        NOTE(review): the schema lists hit_num, star_num, trample_num and
        dataset_id, which ``ParagraphSerializer`` does not expose — confirm
        which side is authoritative.

        :return: ``openapi.Schema``
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
                      'document_id', 'title',
                      'create_time', 'update_time'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                          description="段落内容", default='段落内容'),
                'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
                                        description="标题", default="xxx的描述"),
                'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                          default=1),
                'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                           description="点赞数量", default=1),
                'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                              description="点踩数", default=1),
                'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
                                             description="知识库id", default='xxx'),
                'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                              description="文档id", default='xxx'),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                            description="是否可用", default=True),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              )
            }
        )
class BatchGenerateRelated(ApiMixin, serializers.Serializer):
    """Trigger generation of related content for a batch of paragraphs."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

    @transaction.atomic
    def batch_generate_related(self, instance: Dict, with_valid=True):
        """
        Hand the paragraph ids, model id and prompt to ``generate_related_by_paragraph_id_list``.

        The work is dispatched through the task's ``.delay``; nothing is returned.

        :param instance: dict carrying ``paragraph_id_list``, ``model_id`` and ``prompt``
        :param with_valid: whether to validate the dataset/document ids first
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        task_arguments = (
            instance.get("paragraph_id_list"),
            instance.get("model_id"),
            instance.get("prompt"),
        )
        generate_related_by_paragraph_id_list.delay(*task_arguments)