2023-11-16 05:16:27 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: maxkb
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: chat_serializers.py
|
|
|
|
|
|
@date:2023/11/14 9:59
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
|
|
|
|
|
import datetime
|
|
|
|
|
|
import os
|
2024-01-25 11:19:46 +00:00
|
|
|
|
import re
|
2023-11-16 05:16:27 +00:00
|
|
|
|
import uuid
|
2024-01-16 08:46:54 +00:00
|
|
|
|
from functools import reduce
|
2024-08-23 07:36:38 +00:00
|
|
|
|
from io import BytesIO
|
2024-03-06 05:43:45 +00:00
|
|
|
|
from typing import Dict
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
2024-08-23 07:36:38 +00:00
|
|
|
|
import openpyxl
|
2024-01-25 11:19:46 +00:00
|
|
|
|
from django.core import validators
|
2024-04-02 08:19:36 +00:00
|
|
|
|
from django.core.cache import caches
|
2024-01-25 11:19:46 +00:00
|
|
|
|
from django.db import transaction, models
|
|
|
|
|
|
from django.db.models import QuerySet, Q
|
2024-09-19 09:11:50 +00:00
|
|
|
|
from django.http import StreamingHttpResponse
|
2024-10-15 06:01:05 +00:00
|
|
|
|
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from rest_framework import serializers
|
|
|
|
|
|
|
2024-07-01 01:45:59 +00:00
|
|
|
|
from application.flow.workflow_manage import Flow
|
|
|
|
|
|
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord, WorkFlowVersion, \
|
|
|
|
|
|
ApplicationTypeChoices
|
2024-04-25 06:05:59 +00:00
|
|
|
|
from application.models.api_key_model import ApplicationAccessToken
|
2024-01-16 08:46:54 +00:00
|
|
|
|
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
|
|
|
|
|
ModelSettingSerializer
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from application.serializers.chat_message_serializers import ChatInfo
|
2024-04-25 08:17:29 +00:00
|
|
|
|
from common.constants.permission_constants import RoleConstants
|
2024-01-25 11:19:46 +00:00
|
|
|
|
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from common.exception.app_exception import AppApiException
|
2024-02-19 02:55:13 +00:00
|
|
|
|
from common.util.common import post
|
2024-03-04 02:12:18 +00:00
|
|
|
|
from common.util.field_message import ErrMessage
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from common.util.file_util import get_file_content
|
|
|
|
|
|
from common.util.lock import try_lock, un_lock
|
2024-03-29 10:27:08 +00:00
|
|
|
|
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
2024-08-23 09:46:05 +00:00
|
|
|
|
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
|
2024-03-04 09:01:16 +00:00
|
|
|
|
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
2024-11-12 09:49:22 +00:00
|
|
|
|
from embedding.task import embedding_by_paragraph, embedding_by_paragraph_list
|
2024-08-23 09:46:05 +00:00
|
|
|
|
from setting.models import Model
|
|
|
|
|
|
from setting.models_provider import get_model_credential
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from smartdoc.conf import PROJECT_DIR
|
|
|
|
|
|
|
2024-07-25 02:41:38 +00:00
|
|
|
|
chat_cache = caches['chat_cache']
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
2024-07-01 01:45:59 +00:00
|
|
|
|
class WorkFlowSerializers(serializers.Serializer):
    """Validate a workflow payload consisting of node and edge dictionaries."""

    # Workflow nodes; every entry is accepted as a free-form dict.
    nodes = serializers.ListSerializer(
        child=serializers.DictField(),
        error_messages=ErrMessage.uuid("节点"),
    )
    # Workflow edges; every entry is accepted as a free-form dict.
    edges = serializers.ListSerializer(
        child=serializers.DictField(),
        error_messages=ErrMessage.uuid("连线"),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-08-23 09:46:05 +00:00
|
|
|
|
def valid_model_params_setting(model_id, model_params_setting):
    """Validate model parameter settings against the model's parameter form.

    Nothing is checked when ``model_id`` is ``None``. When
    ``model_params_setting`` is ``None`` or empty, the form's default data is
    validated instead.

    :param model_id: primary key of the ``Model`` row, or ``None``.
    :param model_params_setting: mapping of parameter values, or ``None``.
    :raises AppApiException: if no model row exists for ``model_id``.
    """
    if model_id is None:
        return
    model = QuerySet(Model).filter(id=model_id).first()
    if model is None:
        # .first() yields None for an unknown id; report it as an API error
        # instead of failing with AttributeError on the accesses below.
        raise AppApiException(500, "模型不存在")
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    model_params_setting_form = credential.get_model_params_setting_form(model.model_name)
    if not model_params_setting:
        model_params_setting = model_params_setting_form.get_default_form_data()
    # Reuse the form built above rather than constructing a second one.
    model_params_setting_form.valid_form(model_params_setting)
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class ChatSerializers(serializers.Serializer):
|
2023-12-06 08:29:14 +00:00
|
|
|
|
class Operate(serializers.Serializer):
    """Soft-delete or hard-delete a single chat of an application."""

    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))

    def logic_delete(self, with_valid=True):
        """Mark the matching chat as deleted without removing its row.

        :param with_valid: validate the serializer input first when true.
        :return: always ``True``.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        lookup = {'id': self.data.get('chat_id'), 'application_id': self.data.get('application_id')}
        QuerySet(Chat).filter(**lookup).update(is_deleted=True)
        return True

    def delete(self, with_valid=True):
        """Remove the matching chat row.

        :param with_valid: validate the serializer input first when true.
        :return: always ``True``.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        lookup = {'id': self.data.get('chat_id'), 'application_id': self.data.get('application_id')}
        QuerySet(Chat).filter(**lookup).delete()
        return True
|
|
|
|
|
|
|
2024-05-20 09:50:14 +00:00
|
|
|
|
class ClientChatHistory(serializers.Serializer):
    """Page through the non-deleted chats of one client within an application."""

    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
    client_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("客户端id"))

    def page(self, current_page: int, page_size: int, with_valid=True):
        """Return one page of the client's chats, newest first.

        :param current_page: page number forwarded to ``page_search``.
        :param page_size: page size forwarded to ``page_search``.
        :param with_valid: validate the serializer input first when true.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        chats = QuerySet(Chat).filter(
            client_id=self.data.get('client_id'),
            application_id=self.data.get('application_id'),
            is_deleted=False,
        ).order_by('-create_time')

        def to_payload(row):
            # Each row is serialized through ChatSerializerModel.
            return ChatSerializerModel(row).data

        return page_search(current_page, page_size, chats, to_payload)
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class Query(serializers.Serializer):
|
2024-03-04 02:12:18 +00:00
|
|
|
|
# Optional text matched (case-insensitively) inside the chat abstract.
abstract = serializers.CharField(required=False, error_messages=ErrMessage.char("摘要"))
# Date window (YYYY-MM-DD) applied to the chat creation time.
start_time = serializers.DateField(format='%Y-%m-%d', error_messages=ErrMessage.date("开始时间"))
end_time = serializers.DateField(format='%Y-%m-%d', error_messages=ErrMessage.date("结束时间"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
# Optional lower bounds on the up-vote / down-vote counts.
min_star = serializers.IntegerField(required=False, min_value=0,
                                    error_messages=ErrMessage.integer("最小点赞数"))
min_trample = serializers.IntegerField(required=False, min_value=0,
                                       error_messages=ErrMessage.integer("最小点踩数"))
# How min_star and min_trample are combined when both are supplied.
# The alternation is grouped so both anchors apply to both words: the
# ungrouped "^and|or$" means "starts with 'and'" OR "ends with 'or'" and
# therefore also accepted values such as "android" or "color".
comparer = serializers.CharField(required=False, error_messages=ErrMessage.char("比较器"), validators=[
    validators.RegexValidator(regex=re.compile("^(and|or)$"),
                              message="只支持and|or", code=500)
])
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
def get_end_time(self):
    """Return the latest representable instant of the ``end_time`` day.

    ``end_time`` is read from ``self.data`` and parsed as ``YYYY-MM-DD``.
    """
    end_day = datetime.datetime.strptime(self.data.get('end_time'), '%Y-%m-%d')
    last_moment = datetime.datetime.max.time()
    return datetime.datetime.combine(end_day, last_moment)
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
2024-10-15 10:47:07 +00:00
|
|
|
|
def get_start_time(self):
    """Return the ``start_time`` entry of ``self.data`` unchanged."""
    start_time = self.data.get('start_time')
    return start_time
|
|
|
|
|
|
|
|
|
|
|
|
def get_query_set(self, select_ids=None):
    """Build the filtered, newest-first query set over the chat projection.

    :param select_ids: optional chat ids; when non-empty, only those chats
        are kept.
    :return: query set on a dynamic model, filtered by application id,
        creation-time window, optional abstract text and optional vote
        thresholds.
    """
    end_time = self.get_end_time()
    start_time = self.get_start_time()
    dynamic_model = get_dynamics_model({
        'application_chat.application_id': models.CharField(),
        'application_chat.abstract': models.CharField(),
        "star_num": models.IntegerField(),
        'trample_num': models.IntegerField(),
        'comparer': models.CharField(),
        'application_chat.create_time': models.DateTimeField(),
        'application_chat.id': models.UUIDField(),
    })

    filters = {
        'application_chat.application_id': self.data.get("application_id"),
        'application_chat.create_time__gte': start_time,
        'application_chat.create_time__lte': end_time,
    }
    abstract = self.data.get('abstract')
    if abstract is not None:
        filters['application_chat.abstract__icontains'] = abstract
    if select_ids is not None and len(select_ids) > 0:
        filters['application_chat.id__in'] = select_ids
    condition = Q(**filters)

    min_star = self.data.get('min_star')
    min_trample = self.data.get('min_trample')
    star_q = Q(star_num__gte=min_star) if min_star is not None else None
    trample_q = Q(trample_num__gte=min_trample) if min_trample is not None else None

    if star_q is not None and trample_q is not None:
        # Both thresholds given: 'or' joins them with OR, anything else with AND.
        if self.data.get('comparer') == 'or':
            condition = condition & (star_q | trample_q)
        else:
            condition = condition & (star_q & trample_q)
    elif star_q is not None:
        condition = condition & star_q
    elif trample_q is not None:
        condition = condition & trample_q

    return QuerySet(model=dynamic_model).filter(condition).order_by("-application_chat.create_time")
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
def list(self, with_valid=True):
    """Return all chats matching the query via ``list_application_chat.sql``.

    :param with_valid: validate the serializer input first when true.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    sql_path = os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')
    return native_search(self.get_query_set(),
                         select_string=get_file_content(sql_path),
                         with_table_name=False)
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
2024-10-29 09:29:19 +00:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def paragraph_list_to_string(paragraph_list):
|
|
|
|
|
|
return "\n**********\n".join(
|
2024-10-30 12:56:18 +00:00
|
|
|
|
[f"{paragraph.get('title')}:\n{paragraph.get('content')}" for paragraph in
|
|
|
|
|
|
paragraph_list] if paragraph_list is not None else '')
|
2024-10-29 09:29:19 +00:00
|
|
|
|
|
2024-04-02 08:19:36 +00:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def to_row(row: Dict):
|
|
|
|
|
|
details = row.get('details')
|
|
|
|
|
|
padding_problem_text = details.get('problem_padding').get(
|
|
|
|
|
|
'padding_problem_text') if 'problem_padding' in details and 'padding_problem_text' in details.get(
|
|
|
|
|
|
'problem_padding') else ""
|
2024-10-29 09:29:19 +00:00
|
|
|
|
search_dataset_node_list = [(key, node) for key, node in details.items() if
|
|
|
|
|
|
node.get("type") == 'search-dataset-node' or node.get(
|
|
|
|
|
|
"step_type") == 'search_step']
|
|
|
|
|
|
reference_paragraph_len = '\n'.join([str(len(node.get('paragraph_list',
|
|
|
|
|
|
[]))) if key == 'search_step' else node.get(
|
2024-10-30 12:56:18 +00:00
|
|
|
|
'name') + ':' + str(
|
|
|
|
|
|
len(node.get('paragraph_list', [])) if node.get('paragraph_list', []) is not None else '0') for
|
2024-10-29 09:29:19 +00:00
|
|
|
|
key, node in search_dataset_node_list])
|
|
|
|
|
|
reference_paragraph = '\n----------\n'.join(
|
|
|
|
|
|
[ChatSerializers.Query.paragraph_list_to_string(node.get('paragraph_list',
|
|
|
|
|
|
[])) if key == 'search_step' else node.get(
|
|
|
|
|
|
'name') + ':\n' + ChatSerializers.Query.paragraph_list_to_string(node.get('paragraph_list',
|
2024-10-30 12:56:18 +00:00
|
|
|
|
[])) for
|
2024-10-29 09:29:19 +00:00
|
|
|
|
key, node in search_dataset_node_list])
|
2024-04-02 08:19:36 +00:00
|
|
|
|
improve_paragraph_list = row.get('improve_paragraph_list')
|
|
|
|
|
|
vote_status_map = {'-1': '未投票', '0': '赞同', '1': '反对'}
|
|
|
|
|
|
return [str(row.get('chat_id')), row.get('abstract'), row.get('problem_text'), padding_problem_text,
|
2024-10-29 09:29:19 +00:00
|
|
|
|
row.get('answer_text'), vote_status_map.get(row.get('vote_status')), reference_paragraph_len,
|
|
|
|
|
|
reference_paragraph,
|
2024-04-02 08:19:36 +00:00
|
|
|
|
"\n".join([
|
|
|
|
|
|
f"{improve_paragraph_list[index].get('title')}\n{improve_paragraph_list[index].get('content')}"
|
|
|
|
|
|
for index in range(len(improve_paragraph_list))]),
|
|
|
|
|
|
row.get('message_tokens') + row.get('answer_tokens'), row.get('run_time'),
|
2024-10-25 07:30:54 +00:00
|
|
|
|
str(row.get('create_time').strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
|
|
)]
|
2024-04-02 08:19:36 +00:00
|
|
|
|
|
2024-10-15 10:47:07 +00:00
|
|
|
|
def export(self, data, with_valid=True):
    """Export the matching chat records as an ``.xlsx`` attachment.

    :param data: mapping whose optional ``select_ids`` entry limits the
        export to those chats.
    :param with_valid: validate the serializer input first when true.
    :return: ``StreamingHttpResponse`` carrying the workbook bytes.
    """
    if with_valid:
        self.is_valid(raise_exception=True)

    sql_path = os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'export_application_chat.sql')
    data_list = native_search(self.get_query_set(data.get('select_ids')),
                              select_string=get_file_content(sql_path),
                              with_table_name=False)

    headers = ['会话ID', '摘要', '用户问题', '优化后问题', '回答', '用户反馈', '引用分段数',
               '分段标题+内容',
               '标注', '消耗tokens', '耗时(s)', '提问时间']

    def stream_response():
        # The workbook is assembled fully in memory and emitted as one chunk.
        workbook = openpyxl.Workbook()
        try:
            worksheet = workbook.active
            worksheet.title = 'Sheet1'
            for col_idx, header in enumerate(headers, 1):
                worksheet.cell(row=1, column=col_idx).value = header
            # Row 1 holds the headers, so data rows start at row 2.
            for row_idx, row in enumerate(data_list, start=2):
                for col_idx, value in enumerate(self.to_row(row), 1):
                    if isinstance(value, str):
                        # Drop characters matched by openpyxl's ILLEGAL_CHARACTERS_RE.
                        value = ILLEGAL_CHARACTERS_RE.sub('', value)
                    worksheet.cell(row=row_idx, column=col_idx).value = value
            with BytesIO() as output:
                workbook.save(output)
                yield output.getvalue()
        finally:
            # Release the workbook even if building or saving fails.
            workbook.close()

    # Registered media type for .xlsx; the previous value
    # ("vnd.open.xmlformats-...") was misspelled.
    response = StreamingHttpResponse(
        stream_response(),
        content_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
    response['Content-Disposition'] = 'attachment; filename="data.xlsx"'
    return response
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
def page(self, current_page: int, page_size: int, with_valid=True):
    """Return one page of chats matching the query via ``list_application_chat.sql``.

    :param current_page: page number forwarded to ``native_page_search``.
    :param page_size: page size forwarded to ``native_page_search``.
    :param with_valid: validate the serializer input first when true.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    sql_path = os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')
    return native_page_search(current_page, page_size, self.get_query_set(),
                              select_string=get_file_content(sql_path),
                              with_table_name=False)
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
class OpenChat(serializers.Serializer):
    """Open a chat for a stored application and cache its ``ChatInfo``."""

    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))

    def is_valid(self, *, raise_exception=False):
        """Validate the fields and check the application belongs to the user.

        Field errors are always raised; ``raise_exception`` is accepted only
        to keep the serializer signature.

        :return: ``True`` when validation passes, as serializer callers expect.
        :raises AppApiException: if no application matches both ids.
        """
        super().is_valid(raise_exception=True)
        user_id = self.data.get('user_id')
        application_id = self.data.get('application_id')
        if not QuerySet(Application).filter(id=application_id, user_id=user_id).exists():
            raise AppApiException(500, '应用不存在')
        return True

    def open(self):
        """Validate, load the application and open the chat for its type.

        :return: the new chat id.
        """
        self.is_valid(raise_exception=True)
        application = QuerySet(Application).get(id=self.data.get('application_id'))
        if application.type == ApplicationTypeChoices.SIMPLE:
            return self.open_simple(application)
        return self.open_work_flow(application)

    def open_work_flow(self, application):
        """Open a chat bound to the application's newest workflow version.

        :param application: the loaded ``Application``.
        :return: the new chat id.
        :raises AppApiException: if the application has no workflow version.
        """
        self.is_valid(raise_exception=True)
        application_id = self.data.get('application_id')
        chat_id = str(uuid.uuid1())
        work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by(
            '-create_time')[0:1].first()
        if work_flow_version is None:
            raise AppApiException(500, "应用未发布,请发布后再使用")
        chat_info = ChatInfo(chat_id, [], [], application, work_flow_version)
        # Cache entry expires after 30 minutes.
        chat_cache.set(chat_id, chat_info, timeout=60 * 30)
        return chat_id

    def open_simple(self, application):
        """Open a chat for a simple application.

        :param application: the loaded ``Application``.
        :return: the new chat id.
        """
        valid_model_params_setting(application.model_id, application.model_params_setting)
        application_id = self.data.get('application_id')
        dataset_id_list = [str(row.dataset_id) for row in
                           QuerySet(ApplicationDatasetMapping).filter(application_id=application_id)]
        chat_id = str(uuid.uuid1())
        # Ids of the inactive documents in the mapped datasets.
        inactive_document_ids = [str(document.id) for document in
                                 QuerySet(Document).filter(dataset_id__in=dataset_id_list, is_active=False)]
        chat_info = ChatInfo(chat_id, dataset_id_list, inactive_document_ids, application)
        # Cache entry expires after 30 minutes.
        chat_cache.set(chat_id, chat_info, timeout=60 * 30)
        return chat_id
|
|
|
|
|
|
|
2024-07-01 01:45:59 +00:00
|
|
|
|
class OpenWorkFlowChat(serializers.Serializer):
    """Open a chat for a workflow supplied in the request instead of a stored one."""

    work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

    def open(self):
        """Validate the workflow, build unsaved model objects and cache the chat.

        :return: the new chat id.
        """
        self.is_valid(raise_exception=True)
        work_flow = self.data.get('work_flow')
        Flow.new_instance(work_flow).is_valid()
        chat_id = str(uuid.uuid1())
        # In-memory application (id=None); nothing is saved here.
        application = Application(
            id=None,
            dialogue_number=3,
            model=None,
            dataset_setting={},
            model_setting={},
            problem_optimization=None,
            type=ApplicationTypeChoices.WORK_FLOW,
            user_id=self.data.get('user_id'),
        )
        work_flow_version = WorkFlowVersion(work_flow=work_flow)
        chat_info = ChatInfo(chat_id, [], [], application, work_flow_version)
        # Cache entry expires after 30 minutes.
        chat_cache.set(chat_id, chat_info, timeout=60 * 30)
        return chat_id
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class OpenTempChat(serializers.Serializer):
    """Open a chat from application settings supplied in the request."""

    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
    id = serializers.UUIDField(required=False, allow_null=True,
                               error_messages=ErrMessage.uuid("应用id"))
    model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                     error_messages=ErrMessage.uuid("模型id"))
    multiple_rounds_dialogue = serializers.BooleanField(required=True,
                                                        error_messages=ErrMessage.boolean("多轮会话"))
    dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
                                                 error_messages=ErrMessage.list("关联数据集"))
    # Dataset-related settings
    dataset_setting = DatasetSettingSerializer(required=True)
    # Model-related settings
    model_setting = ModelSettingSerializer(required=True)
    # Problem optimization (question completion)
    problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
    # Model parameter settings
    model_params_setting = serializers.JSONField(required=False, error_messages=ErrMessage.dict("模型参数相关设置"))

    def is_valid(self, *, raise_exception=False):
        """Validate the fields and the model/dataset association.

        Field errors are always raised; ``raise_exception`` is accepted only
        to keep the serializer signature.

        :return: the effective user id (see ``get_user_id``).
        """
        super().is_valid(raise_exception=True)
        user_id = self.get_user_id()
        ModelDatasetAssociation(
            data={'user_id': user_id, 'model_id': self.data.get('model_id'),
                  'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
        return user_id

    def get_user_id(self):
        """Return the owner of application ``id`` when given, else the request ``user_id``.

        :raises AppApiException: if ``id`` is given but no application matches.
        """
        application_id = self.data.get('id')
        if application_id is None:
            return self.data.get('user_id')
        application = QuerySet(Application).filter(id=application_id).first()
        if application is None:
            raise AppApiException(500, "应用不存在")
        return application.user_id

    def open(self):
        """Validate, build an unsaved application and cache the chat.

        :return: the new chat id.
        """
        user_id = self.is_valid(raise_exception=True)
        chat_id = str(uuid.uuid1())
        model_id = self.data.get('model_id')
        # dataset_id_list is optional; fall back to an empty list so the
        # ``dataset_id__in`` lookup below never receives None.
        dataset_id_list = self.data.get('dataset_id_list') or []
        dialogue_number = 3 if self.data.get('multiple_rounds_dialogue', False) else 0
        model_params_setting = self.data.get('model_params_setting')
        valid_model_params_setting(model_id, model_params_setting)
        application = Application(id=None, dialogue_number=dialogue_number, model_id=model_id,
                                  dataset_setting=self.data.get('dataset_setting'),
                                  model_setting=self.data.get('model_setting'),
                                  problem_optimization=self.data.get('problem_optimization'),
                                  model_params_setting=model_params_setting,
                                  user_id=user_id)
        # Ids of the inactive documents in the selected datasets.
        inactive_document_ids = [str(document.id) for document in
                                 QuerySet(Document).filter(dataset_id__in=dataset_id_list, is_active=False)]
        # Cache entry expires after 30 minutes.
        chat_cache.set(chat_id,
                       ChatInfo(chat_id, dataset_id_list, inactive_document_ids, application),
                       timeout=60 * 30)
        return chat_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatRecordSerializerModel(serializers.ModelSerializer):
    """Model serializer exposing the listed ``ChatRecord`` columns."""

    class Meta:
        model = ChatRecord
        fields = [
            'id',
            'chat_id',
            'vote_status',
            'problem_text',
            'answer_text',
            'message_tokens',
            'answer_tokens',
            'const',
            'improve_paragraph_id_list',
            'run_time',
            'index',
            'answer_text_list',
            'create_time',
            'update_time',
        ]
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
2024-05-20 09:50:14 +00:00
|
|
|
|
class ChatSerializerModel(serializers.ModelSerializer):
    """Model serializer exposing the id, application, abstract and client of a ``Chat``."""

    class Meta:
        model = Chat
        fields = [
            'id',
            'application_id',
            'abstract',
            'client_id',
        ]
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class ChatRecordSerializer(serializers.Serializer):
|
2024-01-16 08:46:54 +00:00
|
|
|
|
class Operate(serializers.Serializer):
    """Fetch a single chat record together with its execution details."""

    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
    chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

    def is_valid(self, *, current_role=None, raise_exception=False):
        """Validate the fields and the caller's right to see source details.

        Field errors are always raised; ``raise_exception`` is accepted only
        to keep the serializer signature.

        :param current_role: role value of the caller, compared with
            ``RoleConstants.APPLICATION_ACCESS_TOKEN.value``.
        :return: ``True`` when validation passes, as serializer callers expect.
        :raises AppApiException: if the application has no access token row,
            or the caller uses an access-token role while ``show_source`` is
            disabled.
        """
        super().is_valid(raise_exception=True)
        application_access_token = QuerySet(ApplicationAccessToken).filter(
            application_id=self.data.get('application_id')).first()
        if application_access_token is None:
            raise AppApiException(500, '不存在的应用认证信息')
        if not application_access_token.show_source and current_role == RoleConstants.APPLICATION_ACCESS_TOKEN.value:
            raise AppApiException(500, '未开启显示知识来源')
        return True

    def get_chat_record(self):
        """Return the record from the cached chat if present, else from the database.

        :return: the last cached record with a matching id, otherwise the
            database row, or ``None`` when neither exists.
        """
        chat_record_id = self.data.get('chat_record_id')
        chat_id = self.data.get('chat_id')
        chat_info: ChatInfo = chat_cache.get(chat_id)
        if chat_info is not None:
            # Ids are compared as strings on both sides.
            matches = [chat_record for chat_record in chat_info.chat_record_list if
                       str(chat_record.id) == str(chat_record_id)]
            if matches:
                return matches[-1]
        return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()

    def one(self, current_role: RoleConstants, with_valid=True):
        """Return the detailed representation of the requested chat record.

        :param current_role: role of the caller, forwarded to ``is_valid``.
        :param with_valid: validate the serializer input first when true.
        :raises AppApiException: if the chat record cannot be found.
        """
        if with_valid:
            self.is_valid(current_role=current_role, raise_exception=True)
        chat_record = self.get_chat_record()
        if chat_record is None:
            raise AppApiException(500, "对话不存在")
        return ChatRecordSerializer.Query.reset_chat_record(chat_record)
|
2024-01-16 08:46:54 +00:00
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class Query(serializers.Serializer):
    """List or page the records of one chat."""

    application_id = serializers.UUIDField(required=True)
    chat_id = serializers.UUIDField(required=True)
    order_asc = serializers.BooleanField(required=False)

    def list(self, with_valid=True):
        """Return all serialized records of the chat.

        Records are ordered by ascending ``create_time`` unless ``order_asc``
        is explicitly false.

        :param with_valid: validate the serializer input first when true.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        order_asc = self.data.get('order_asc')
        order_by = 'create_time' if order_asc is None or order_asc else '-create_time'
        records = QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by(order_by)
        return [ChatRecordSerializerModel(chat_record).data for chat_record in records]

    @staticmethod
    def reset_chat_record(chat_record):
        """Build the detailed dict for a record from its ``details`` mapping.

        :param chat_record: object with a ``details`` mapping that
            ``ChatRecordSerializerModel`` can serialize.
        :return: the serialized record plus ``padding_problem_text``,
            ``dataset_list``, ``paragraph_list`` and ``execution_details``.
        """
        details = chat_record.details
        dataset_list = []
        paragraph_list = []
        if 'search_step' in details and details.get('search_step').get('paragraph_list') is not None:
            paragraph_list = details.get('search_step').get('paragraph_list')
            # One entry per dataset id; a later paragraph overrides the name
            # of an earlier one with the same id.
            dataset_names = {row.get('dataset_id'): row.get("dataset_name") for row in paragraph_list}
            dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in dataset_names.items()]

        return {
            **ChatRecordSerializerModel(chat_record).data,
            'padding_problem_text': details.get('problem_padding').get(
                'padding_problem_text') if 'problem_padding' in details else None,
            'dataset_list': dataset_list,
            'paragraph_list': paragraph_list,
            'execution_details': list(details.values())
        }

    def page(self, current_page: int, page_size: int, with_valid=True):
        """Return one page of detailed records of the chat.

        :param current_page: page number forwarded to ``page_search``.
        :param page_size: page size forwarded to ``page_search``.
        :param with_valid: validate the serializer input first when true.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        order_asc = self.data.get('order_asc')
        # NOTE(review): a missing or true ``order_asc`` sorts DESCENDING here,
        # the opposite of ``list`` - kept as is; confirm with the caller
        # whether this is intentional.
        order_by = '-create_time' if order_asc is None or order_asc else 'create_time'
        page = page_search(current_page, page_size,
                           QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by(order_by),
                           post_records_handler=lambda chat_record: self.reset_chat_record(chat_record))
        return page
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
class Vote(serializers.Serializer):
    """Cast or withdraw a vote on a single chat record."""

    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
    chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))
    # Label fixed from "投标状态" (bid status) to "投票状态" (vote status).
    vote_status = serializers.ChoiceField(choices=VoteChoices.choices, error_messages=ErrMessage.uuid("投票状态"))

    @transaction.atomic
    def vote(self, with_valid=True):
        """Apply the requested vote status to the chat record.

        An un-voted record may be starred or trampled; a voted record may
        only be reset to un-voted.

        :param with_valid: validate the serializer input first when true.
        :return: always ``True``.
        :raises AppApiException: if the record lock cannot be taken, the
            record does not exist, or the record already carries a vote and
            the request is not a withdrawal.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        chat_record_id = self.data.get('chat_record_id')
        if not try_lock(chat_record_id):
            raise AppApiException(500, "正在对当前会话纪要进行投票中,请勿重复发送请求")
        try:
            # filter().first() yields None for a missing row; get() would
            # raise DoesNotExist and skip the explicit check below.
            chat_record_details_model = QuerySet(ChatRecord).filter(
                id=chat_record_id, chat_id=self.data.get('chat_id')).first()
            if chat_record_details_model is None:
                raise AppApiException(500, "不存在的对话 chat_record_id")
            vote_status = self.data.get("vote_status")
            if chat_record_details_model.vote_status == VoteChoices.UN_VOTE:
                if vote_status == VoteChoices.STAR:
                    # Up-vote
                    chat_record_details_model.vote_status = VoteChoices.STAR

                if vote_status == VoteChoices.TRAMPLE:
                    # Down-vote
                    chat_record_details_model.vote_status = VoteChoices.TRAMPLE
                chat_record_details_model.save()
            else:
                if vote_status == VoteChoices.UN_VOTE:
                    # Withdraw the existing vote
                    chat_record_details_model.vote_status = VoteChoices.UN_VOTE
                    chat_record_details_model.save()
                else:
                    raise AppApiException(500, "已经投票过,请先取消后再进行投票")
        finally:
            un_lock(chat_record_id)
        return True
|
|
|
|
|
|
|
|
|
|
|
|
class ImproveSerializer(serializers.Serializer):
    """Validates the request body used when annotating a chat record with a new paragraph."""
    # Paragraph title: optional, may be null or blank, at most 256 characters
    title = serializers.CharField(
        max_length=256,
        required=False,
        allow_blank=True,
        allow_null=True,
        error_messages=ErrMessage.char("段落标题"),
    )
    # Paragraph content: mandatory
    content = serializers.CharField(
        error_messages=ErrMessage.char("段落内容"),
        required=True,
    )

    # Question text: optional, may be null or blank, at most 256 characters
    problem_text = serializers.CharField(
        max_length=256,
        required=False,
        allow_blank=True,
        allow_null=True,
        error_messages=ErrMessage.char("问题"),
    )
|
|
|
|
|
|
|
2023-12-13 09:14:43 +00:00
|
|
|
|
class ParagraphModel(serializers.ModelSerializer):
    """Model serializer exposing every field of Paragraph."""

    class Meta:
        # Serialize all model fields of the Paragraph model
        fields = "__all__"
        model = Paragraph
|
|
|
|
|
|
|
|
|
|
|
|
class ChatRecordImprove(serializers.Serializer):
    """Serializer that lists the paragraphs attached (as annotations) to one chat record."""
    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

    chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

    def get(self, with_valid=True):
        """
        Return the serialized paragraphs referenced by the chat record.

        Ids in improve_paragraph_id_list that no longer match an existing
        Paragraph row are removed from the record and the record is saved.

        :param with_valid: when true, validate the serializer input first
        :return: list of serialized paragraphs (empty when the record has none)
        :raises AppApiException: if no chat record matches chat_id / chat_record_id
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        chat_record_id = self.data.get('chat_record_id')
        chat_id = self.data.get('chat_id')
        chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
        if chat_record is None:
            raise AppApiException(500, '不存在的对话记录')
        if chat_record.improve_paragraph_id_list is None or len(chat_record.improve_paragraph_id_list) == 0:
            return []

        paragraph_model_list = list(QuerySet(Paragraph).filter(id__in=chat_record.improve_paragraph_id_list))
        if len(paragraph_model_list) < len(chat_record.improve_paragraph_id_list):
            # Compare string forms on both sides: the stored ids may be UUID objects,
            # and a UUID never equals its string, which previously emptied the whole list.
            existing_id_set = {str(paragraph.id) for paragraph in paragraph_model_list}
            chat_record.improve_paragraph_id_list = [p_id for p_id in chat_record.improve_paragraph_id_list
                                                     if str(p_id) in existing_id_set]
            chat_record.save()
        return [ChatRecordSerializer.ParagraphModel(p).data for p in paragraph_model_list]
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class Improve(serializers.Serializer):
    """Serializer that annotates a chat record by creating a paragraph (and its question) in a document."""
    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

    chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the fields and check that the document belongs to the dataset.

        Validation errors are always raised, regardless of raise_exception.

        :return: True when validation passes
        :raises AppApiException: if no Document matches document_id within dataset_id
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Document).filter(id=self.data.get('document_id'),
                                         dataset_id=self.data.get('dataset_id')).exists():
            raise AppApiException(500, "文档id不正确")
        # Return a bool like the base class does, so `if serializer.is_valid():` works
        return True

    @staticmethod
    def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
        """
        Post hook for improve(): embed the new paragraph and pass the chat record through.

        Receives the three values returned by improve().
        """
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        # Trigger vectorization of the paragraph
        embedding_by_paragraph(paragraph_id, model_id)
        return chat_record

    @post(post_function=post_embedding_paragraph)
    @transaction.atomic
    def improve(self, instance: Dict, with_valid=True):
        """
        Create a paragraph, a question and their mapping, and attach the paragraph to the chat record.

        :param instance: request body with `content` and optional `title` / `problem_text`
        :param with_valid: when true, validate both this serializer and `instance`
        :return: (serialized chat record, new paragraph id, dataset id), handed to the post hook
        :raises AppApiException: if no chat record matches chat_id / chat_record_id
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            ChatRecordSerializer.ImproveSerializer(data=instance).is_valid(raise_exception=True)
        chat_record_id = self.data.get('chat_record_id')
        chat_id = self.data.get('chat_id')
        chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
        if chat_record is None:
            raise AppApiException(500, '不存在的对话记录')

        document_id = self.data.get("document_id")
        dataset_id = self.data.get("dataset_id")
        # The validator accepts an explicit null title; treat it like a missing one
        paragraph = Paragraph(id=uuid.uuid1(),
                              document_id=document_id,
                              content=instance.get("content"),
                              dataset_id=dataset_id,
                              title=instance.get("title") or '')
        # Fall back to the question stored on the chat record when none is supplied
        problem_text = instance.get('problem_text')
        if problem_text is None:
            problem_text = chat_record.problem_text
        problem = Problem(id=uuid.uuid1(), content=problem_text, dataset_id=dataset_id)
        problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(), dataset_id=dataset_id,
                                                            document_id=document_id,
                                                            problem_id=problem.id,
                                                            paragraph_id=paragraph.id)
        # Insert the question
        problem.save()
        # Insert the paragraph
        paragraph.save()
        # Insert the question-paragraph mapping
        problem_paragraph_mapping.save()
        # The id list may be unset on the record; start a fresh list in that case
        if chat_record.improve_paragraph_id_list is None:
            chat_record.improve_paragraph_id_list = []
        chat_record.improve_paragraph_id_list.append(paragraph.id)
        # Persist the annotation
        chat_record.save()
        return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
|
2024-03-04 09:01:16 +00:00
|
|
|
|
|
|
|
|
|
|
class Operate(serializers.Serializer):
    """Serializer that removes one annotated paragraph from a chat record."""
    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

    chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))

    def delete(self, with_valid=True):
        """
        Detach the paragraph from the chat record, then delete the paragraph itself.

        :param with_valid: when true, validate the serializer input first
        :return: whatever ParagraphSerializers.Operate.delete() returns
        :raises AppApiException: if the chat record does not exist, or the paragraph
                                 id is not in the record's improve_paragraph_id_list
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        chat_record_id = self.data.get('chat_record_id')
        chat_id = self.data.get('chat_id')
        dataset_id = self.data.get('dataset_id')
        document_id = self.data.get('document_id')
        paragraph_id = self.data.get('paragraph_id')
        chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
        if chat_record is None:
            raise AppApiException(500, '不存在的对话记录')
        # An unset list is treated as empty so the "wrong paragraph id" error is raised
        # instead of an AttributeError.
        improve_paragraph_id_list = chat_record.improve_paragraph_id_list or []
        # Membership uses the same string comparison as the removal filter below
        if not any(str(row) == paragraph_id for row in improve_paragraph_id_list):
            raise AppApiException(500, f'段落id错误,当前对话记录不存在【{paragraph_id}】段落id')
        chat_record.improve_paragraph_id_list = [row for row in improve_paragraph_id_list if
                                                 str(row) != paragraph_id]
        chat_record.save()
        o = ParagraphSerializers.Operate(
            data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
        o.is_valid(raise_exception=True)
        return o.delete()
|
2024-11-12 09:49:22 +00:00
|
|
|
|
|
|
|
|
|
|
class PostImprove(serializers.Serializer):
    """Serializer that bulk-annotates the records of several chats into one document."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    chat_ids = serializers.ListSerializer(child=serializers.UUIDField(), required=True,
                                          error_messages=ErrMessage.list("对话id"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the fields and check that the document belongs to the dataset.

        Validation errors are always raised, regardless of raise_exception.

        :return: True when validation passes
        :raises AppApiException: if no Document matches document_id within dataset_id
        """
        super().is_valid(raise_exception=True)
        if not Document.objects.filter(id=self.data['document_id'], dataset_id=self.data['dataset_id']).exists():
            raise AppApiException(500, "文档id不正确")
        # Return a bool like the base class does, so `if serializer.is_valid():` works
        return True

    @staticmethod
    def post_embedding_paragraph(paragraph_ids, dataset_id):
        """Post hook for post_improve(): embed all newly created paragraphs."""
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        embedding_by_paragraph_list(paragraph_ids, model_id)

    @post(post_function=post_embedding_paragraph)
    @transaction.atomic
    def post_improve(self, instance: Dict):
        """
        Create one paragraph per chat record of the given chats and attach it to that record.

        The record's answer becomes the paragraph content and its question becomes
        both the paragraph title and the linked Problem.

        :param instance: request body with `chat_ids`, `document_id` and `dataset_id`
        :return: (list of new paragraph ids, dataset id), handed to the post hook
        :raises AppApiException: if fewer chat records than chat ids were found
        """
        ChatRecordSerializer.PostImprove(data=instance).is_valid(raise_exception=True)

        chat_ids = instance['chat_ids']
        document_id = instance['document_id']
        dataset_id = instance['dataset_id']

        # Load every chat record of the requested chats
        chat_record_list = list(ChatRecord.objects.filter(chat_id__in=chat_ids))
        # NOTE(review): this compares a record count with a chat count, so it only catches
        # the case where fewer records than chats were found - confirm whether a per-chat
        # check is intended.
        if len(chat_record_list) < len(chat_ids):
            raise AppApiException(500, "存在不存在的对话记录")

        # Build the paragraphs and question mappings in memory first
        paragraphs = []
        paragraph_ids = []
        problem_paragraph_mappings = []
        for chat_record in chat_record_list:
            paragraph = Paragraph(
                id=uuid.uuid1(),
                document_id=document_id,
                content=chat_record.answer_text,
                dataset_id=dataset_id,
                title=chat_record.problem_text
            )
            # Reuse an existing question when there is one. filter().first() tolerates
            # duplicate rows, whereas get_or_create() raises MultipleObjectsReturned.
            problem = Problem.objects.filter(content=chat_record.problem_text, dataset_id=dataset_id).first()
            if problem is None:
                problem = Problem.objects.create(content=chat_record.problem_text, dataset_id=dataset_id)
            problem_paragraph_mapping = ProblemParagraphMapping(
                id=uuid.uuid1(),
                dataset_id=dataset_id,
                document_id=document_id,
                problem_id=problem.id,
                paragraph_id=paragraph.id
            )
            paragraphs.append(paragraph)
            paragraph_ids.append(paragraph.id)
            problem_paragraph_mappings.append(problem_paragraph_mapping)
            # The id list may be unset on the record; start a fresh list in that case
            if chat_record.improve_paragraph_id_list is None:
                chat_record.improve_paragraph_id_list = []
            chat_record.improve_paragraph_id_list.append(paragraph.id)

        # Bulk insert the paragraphs and the question mappings
        Paragraph.objects.bulk_create(paragraphs)
        ProblemParagraphMapping.objects.bulk_create(problem_paragraph_mappings)

        # Bulk update the chat records
        ChatRecord.objects.bulk_update(chat_record_list, ['improve_paragraph_id_list'])

        return paragraph_ids, dataset_id
|