223 lines
11 KiB
Python
223 lines
11 KiB
Python
|
|
# coding=utf-8
|
|||
|
|
"""
|
|||
|
|
@project: maxkb
|
|||
|
|
@Author:虎
|
|||
|
|
@file: problem_serializers.py
|
|||
|
|
@date:2023/10/23 13:55
|
|||
|
|
@desc:
|
|||
|
|
"""
|
|||
|
|
import uuid
|
|||
|
|
from typing import Dict
|
|||
|
|
|
|||
|
|
from django.db.models import QuerySet
|
|||
|
|
from drf_yasg import openapi
|
|||
|
|
from rest_framework import serializers
|
|||
|
|
|
|||
|
|
from common.event.listener_manage import ListenerManagement
|
|||
|
|
from common.exception.app_exception import AppApiException
|
|||
|
|
from common.mixins.api_mixin import ApiMixin
|
|||
|
|
from dataset.models import Problem, Paragraph
|
|||
|
|
from embedding.models import SourceType
|
|||
|
|
from embedding.vector.pg_vector import PGVector
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProblemSerializer(serializers.ModelSerializer):
|
|||
|
|
class Meta:
|
|||
|
|
model = Problem
|
|||
|
|
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
|
|||
|
|
'create_time', 'update_time']
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||
|
|
id = serializers.CharField(required=False)
|
|||
|
|
|
|||
|
|
content = serializers.CharField(required=True)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def get_request_body_api():
|
|||
|
|
return openapi.Schema(type=openapi.TYPE_OBJECT,
|
|||
|
|
required=["content"],
|
|||
|
|
properties={
|
|||
|
|
'id': openapi.Schema(
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
title="问题id,修改的时候传递,创建的时候不传"),
|
|||
|
|
'content': openapi.Schema(
|
|||
|
|
type=openapi.TYPE_STRING, title="内容")
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||
|
|
class Create(ApiMixin, serializers.Serializer):
|
|||
|
|
dataset_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
document_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
paragraph_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
|||
|
|
if with_valid:
|
|||
|
|
self.is_valid()
|
|||
|
|
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
|||
|
|
problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
|
|||
|
|
document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
|
|||
|
|
content=instance.get('content'))
|
|||
|
|
problem.save()
|
|||
|
|
if with_embedding:
|
|||
|
|
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
|||
|
|
'is_active': True,
|
|||
|
|
'source_type': SourceType.PROBLEM,
|
|||
|
|
'source_id': problem.id,
|
|||
|
|
'document_id': self.data.get('document_id'),
|
|||
|
|
'paragraph_id': self.data.get('paragraph_id'),
|
|||
|
|
'dataset_id': self.data.get('dataset_id')})
|
|||
|
|
|
|||
|
|
return ProblemSerializers.Operate(
|
|||
|
|
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
|
|||
|
|
'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def get_request_body_api():
|
|||
|
|
return ProblemInstanceSerializer.get_request_body_api()
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def get_request_params_api():
|
|||
|
|
return [openapi.Parameter(name='dataset_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='数据集id'),
|
|||
|
|
openapi.Parameter(name='document_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='文档id'),
|
|||
|
|
openapi.Parameter(name='paragraph_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='段落id')]
|
|||
|
|
|
|||
|
|
class Query(ApiMixin, serializers.Serializer):
|
|||
|
|
dataset_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
document_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
paragraph_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
def is_valid(self, *, raise_exception=True):
|
|||
|
|
super().is_valid(raise_exception=True)
|
|||
|
|
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
|
|||
|
|
raise AppApiException(500, "段落id不存在")
|
|||
|
|
|
|||
|
|
def get_query_set(self):
|
|||
|
|
dataset_id = self.data.get('dataset_id')
|
|||
|
|
document_id = self.data.get('document_id')
|
|||
|
|
paragraph_id = self.data.get("paragraph_id")
|
|||
|
|
return QuerySet(Problem).filter(
|
|||
|
|
**{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
|
|||
|
|
|
|||
|
|
def list(self, with_valid=False):
|
|||
|
|
"""
|
|||
|
|
获取问题列表
|
|||
|
|
:param with_valid: 是否校验
|
|||
|
|
:return: 问题列表
|
|||
|
|
"""
|
|||
|
|
if with_valid:
|
|||
|
|
self.is_valid(raise_exception=True)
|
|||
|
|
query_set = self.get_query_set()
|
|||
|
|
return [ProblemSerializer(p).data for p in query_set]
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def get_request_params_api():
|
|||
|
|
return [openapi.Parameter(name='dataset_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='数据集id'),
|
|||
|
|
openapi.Parameter(name='document_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='文档id')
|
|||
|
|
, openapi.Parameter(name='paragraph_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='段落id')]
|
|||
|
|
|
|||
|
|
class Operate(ApiMixin, serializers.Serializer):
|
|||
|
|
dataset_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
document_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
paragraph_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
problem_id = serializers.UUIDField(required=True)
|
|||
|
|
|
|||
|
|
def delete(self, with_valid=False):
|
|||
|
|
if with_valid:
|
|||
|
|
self.is_valid(raise_exception=True)
|
|||
|
|
QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
|
|||
|
|
PGVector().delete_by_source_id(self.data.get('problem_id'), SourceType.PROBLEM)
|
|||
|
|
ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
def one(self, with_valid=False):
|
|||
|
|
if with_valid:
|
|||
|
|
self.is_valid(raise_exception=True)
|
|||
|
|
return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def get_request_params_api():
|
|||
|
|
return [openapi.Parameter(name='dataset_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='数据集id'),
|
|||
|
|
openapi.Parameter(name='document_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='文档id')
|
|||
|
|
, openapi.Parameter(name='paragraph_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='段落id'),
|
|||
|
|
openapi.Parameter(name='problem_id',
|
|||
|
|
in_=openapi.IN_PATH,
|
|||
|
|
type=openapi.TYPE_STRING,
|
|||
|
|
required=True,
|
|||
|
|
description='问题id')
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def get_response_body_api():
|
|||
|
|
return openapi.Schema(
|
|||
|
|
type=openapi.TYPE_OBJECT,
|
|||
|
|
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
|
|||
|
|
'document_id',
|
|||
|
|
'create_time', 'update_time'],
|
|||
|
|
properties={
|
|||
|
|
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
|||
|
|
description="id", default="xx"),
|
|||
|
|
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
|
|||
|
|
description="问题内容", default='问题内容'),
|
|||
|
|
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
|||
|
|
default=1),
|
|||
|
|
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
|
|||
|
|
description="点赞数量", default=1),
|
|||
|
|
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
|
|||
|
|
description="点踩数", default=1),
|
|||
|
|
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
|
|||
|
|
description="文档id", default='xxx'),
|
|||
|
|
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
|||
|
|
description="修改时间",
|
|||
|
|
default="1970-01-01 00:00:00"),
|
|||
|
|
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
|||
|
|
description="创建时间",
|
|||
|
|
default="1970-01-01 00:00:00"
|
|||
|
|
)
|
|||
|
|
}
|
|||
|
|
)
|