2023-11-16 05:16:27 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: maxkb
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: chat_message_serializers.py
|
|
|
|
|
|
@date:2023/11/14 13:51
|
|
|
|
|
|
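@desc: Chat serializers: per-chat session state (ChatInfo), the OpenAI-compatible chat entry point and the application chat-message serializer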
|
|
|
|
|
|
"""
|
2024-07-01 01:45:59 +00:00
|
|
|
|
import uuid
|
2024-09-09 06:47:25 +00:00
|
|
|
|
from typing import List, Dict
|
2024-01-16 08:46:54 +00:00
|
|
|
|
from uuid import UUID
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
2024-03-06 05:43:45 +00:00
|
|
|
|
from django.core.cache import caches
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from django.db.models import QuerySet
|
2024-01-16 08:46:54 +00:00
|
|
|
|
from rest_framework import serializers
|
|
|
|
|
|
|
2024-04-15 05:45:23 +00:00
|
|
|
|
from application.chat_pipeline.pipeline_manage import PipelineManage
|
2024-01-16 08:46:54 +00:00
|
|
|
|
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
|
|
|
|
|
|
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
|
|
|
|
|
|
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
|
|
|
|
|
|
BaseGenerateHumanMessageStep
|
|
|
|
|
|
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
|
|
|
|
|
|
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
2024-07-01 01:45:59 +00:00
|
|
|
|
from application.flow.i_step_node import WorkFlowPostHandler
|
|
|
|
|
|
from application.flow.workflow_manage import WorkflowManage, Flow
|
|
|
|
|
|
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices, \
|
|
|
|
|
|
WorkFlowVersion
|
2024-03-13 21:43:01 +00:00
|
|
|
|
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
|
|
|
|
|
|
from common.constants.authentication_type import AuthenticationType
|
2024-09-09 06:47:25 +00:00
|
|
|
|
from common.exception.app_exception import AppChatNumOutOfBoundsFailed, ChatException
|
|
|
|
|
|
from common.handle.base_to_response import BaseToResponse
|
|
|
|
|
|
from common.handle.impl.response.openai_to_response import OpenaiToResponse
|
|
|
|
|
|
from common.handle.impl.response.system_to_response import SystemToResponse
|
2024-03-13 08:07:13 +00:00
|
|
|
|
from common.util.field_message import ErrMessage
|
2024-01-16 08:46:54 +00:00
|
|
|
|
from common.util.split_model import flat_map
|
|
|
|
|
|
from dataset.models import Paragraph, Document
|
2024-03-22 12:13:04 +00:00
|
|
|
|
from setting.models import Model, Status
|
2024-08-23 09:46:05 +00:00
|
|
|
|
from setting.models_provider import get_model_credential
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
2024-07-25 02:41:38 +00:00
|
|
|
|
# Name of the Django cache backend that stores ChatInfo objects keyed by chat id
# (written by get_post_handler / ChatMessageSerializer.get_chat_info with a 30 minute timeout).
_CHAT_CACHE_ALIAS = 'chat_cache'
chat_cache = caches[_CHAT_CACHE_ALIAS]
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatInfo:
    """In-memory state of one chat session.

    Holds the application the chat belongs to, the dataset scope used by the simple
    (pipeline) chat flow, the published workflow version used by the workflow chat flow,
    and the chat records accumulated so far. Instances are stored in ``chat_cache``.
    """

    def __init__(self,
                 chat_id: str,
                 dataset_id_list: List[str],
                 exclude_document_id_list: list[str],
                 application: Application,
                 work_flow_version: WorkFlowVersion = None):
        """
        :param chat_id: chat (conversation) id
        :param dataset_id_list: ids of the datasets searched by the simple chat pipeline
        :param exclude_document_id_list: ids of documents excluded from the search
        :param application: application the chat belongs to
        :param work_flow_version: published workflow version (workflow applications only)
        """
        self.chat_id = chat_id
        self.application = application
        self.dataset_id_list = dataset_id_list
        self.exclude_document_id_list = exclude_document_id_list
        self.chat_record_list: List[ChatRecord] = []
        self.work_flow_version = work_flow_version

    @staticmethod
    def get_no_references_setting(dataset_setting, model_setting):
        """Return the behaviour to use when the dataset search finds no paragraphs.

        Defaults to ``{'status': 'ai_questioning', 'value': '{question}'}`` when the dataset
        setting has none. In ``ai_questioning`` mode the value is replaced by the model's
        ``no_references_prompt``, or ``'{question}'`` when that prompt is empty or None.

        A copy is returned so the caller's ``dataset_setting`` (the application's stored
        setting) is never mutated in place.
        """
        stored_setting = dataset_setting.get('no_references_setting')
        if stored_setting is None:
            stored_setting = {'status': 'ai_questioning', 'value': '{question}'}
        no_references_setting = dict(stored_setting)
        if no_references_setting.get('status') == 'ai_questioning':
            no_references_prompt = model_setting.get('no_references_prompt', '{question}')
            no_references_setting['value'] = no_references_prompt if no_references_prompt else "{question}"
        return no_references_setting

    def to_base_pipeline_manage_params(self):
        """Build the pipeline parameters that depend only on the application and chat history."""
        dataset_setting = self.application.dataset_setting
        model_setting = self.application.model_setting
        model_id = self.application.model.id if self.application.model is not None else None
        model_params_setting = None
        if model_id is not None:
            model = QuerySet(Model).filter(id=model_id).first()
            # The lookup can return None; fall back to "no default params" instead of crashing.
            if model is not None:
                credential = get_model_credential(model.provider, model.model_type, model.model_name)
                model_params_setting = credential.get_model_params_setting_form(
                    model.model_name).get_default_form_data()
        problem_optimization_prompt = self.application.problem_optimization_prompt
        prompt = model_setting.get('prompt')
        return {
            'dataset_id_list': self.dataset_id_list,
            'exclude_document_id_list': self.exclude_document_id_list,
            'exclude_paragraph_id_list': [],
            'top_n': dataset_setting.get('top_n', 3),
            'similarity': dataset_setting.get('similarity', 0.6),
            'max_paragraph_char_number': dataset_setting.get('max_paragraph_char_number', 5000),
            'history_chat_record': self.chat_record_list,
            'chat_id': self.chat_id,
            'dialogue_number': self.application.dialogue_number,
            # Empty or missing prompt -> built-in question-completion prompt.
            'problem_optimization_prompt': problem_optimization_prompt if problem_optimization_prompt else '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中',
            # Empty or missing (including None) prompt -> application default prompt.
            'prompt': prompt if prompt else Application.get_default_model_prompt(),
            'system': model_setting.get('system'),
            'model_id': model_id,
            'problem_optimization': self.application.problem_optimization,
            'stream': True,
            # Explicit application params win; otherwise use the model's default form data.
            'model_params_setting': self.application.model_params_setting or model_params_setting,
            'search_mode': dataset_setting.get('search_mode', 'embedding'),
            'no_references_setting': self.get_no_references_setting(dataset_setting, model_setting),
            'user_id': self.application.user_id
        }

    def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
                                  exclude_paragraph_id_list, client_id: str, client_type, stream=True):
        """Build the full parameter dict for one pipeline run of ``problem_text``."""
        params = self.to_base_pipeline_manage_params()
        return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
                'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id,
                'client_type': client_type}

    def append_chat_record(self, chat_record: ChatRecord, client_id=None):
        """Add ``chat_record`` to the in-memory history and persist it for saved applications.

        Problem text is truncated to 10240 characters and answer text to 40960; a None
        value becomes ``""``. When the application has an id, the ``Chat`` row is created
        if missing (abstract = first 1024 chars of the problem) and the record is saved.
        """
        chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
        # Check answer_text itself: testing problem_text here let a None answer reach the slice.
        chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.answer_text is not None else ""
        # Keep in the in-memory history (this object is what gets cached)
        self.chat_record_list.append(chat_record)
        if self.application.id is not None:
            # Persist the chat itself on first use
            if not QuerySet(Chat).filter(id=self.chat_id).exists():
                Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
                     client_id=client_id).save()
            # Persist the chat record
            chat_record.save()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_post_handler(chat_info: ChatInfo):
    """Create the post-response hook for the simple (pipeline) chat flow.

    The returned handler turns a finished pipeline run into a ``ChatRecord``, hands it to
    ``chat_info.append_chat_record`` and re-stores ``chat_info`` in ``chat_cache`` with a
    30 minute timeout.
    """

    class PostHandler(PostResponseHandler):

        def handler(self,
                    chat_id: UUID,
                    chat_record_id,
                    paragraph_list: List[Paragraph],
                    problem_text: str,
                    answer_text,
                    manage: PipelineManage,
                    step: BaseChatStep,
                    padding_problem_text: str = None,
                    client_id=None,
                    **kwargs):
            run_details = manage.get_details()
            run_context = manage.context
            # Index is 1-based and computed before the new record is appended.
            next_index = len(chat_info.chat_record_list) + 1
            new_record = ChatRecord(id=chat_record_id,
                                    chat_id=chat_id,
                                    problem_text=problem_text,
                                    answer_text=answer_text,
                                    details=run_details,
                                    message_tokens=run_context['message_tokens'],
                                    answer_tokens=run_context['answer_tokens'],
                                    run_time=run_context['run_time'],
                                    index=next_index)
            chat_info.append_chat_record(new_record, client_id)
            # Re-cache so the next request sees the updated history.
            thirty_minutes = 60 * 30
            chat_cache.set(chat_id, chat_info, timeout=thirty_minutes)

    return PostHandler()
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
2024-09-09 06:47:25 +00:00
|
|
|
|
class OpenAIMessage(serializers.Serializer):
    """One entry of an OpenAI-style ``messages`` array: a ``role`` and its ``content``."""

    content = serializers.CharField(error_messages=ErrMessage.char('内容'), required=True)

    role = serializers.CharField(error_messages=ErrMessage.char('角色'), required=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAIInstanceSerializer(serializers.Serializer):
    """Validates the body of an OpenAI-compatible chat request."""

    # allow_empty=False: OpenAIChatSerializer.get_message reads messages[-1], which
    # raised IndexError on an empty list.
    messages = serializers.ListField(child=OpenAIMessage(), allow_empty=False,
                                     error_messages=ErrMessage.list("消息"))

    # UUID field -> uuid error messages, matching the other serializers' chat_id fields.
    chat_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("对话id"))

    re_chat = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("重新生成"))

    stream = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("流式输出"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAIChatSerializer(serializers.Serializer):
    """OpenAI-compatible chat entry point.

    Converts an OpenAI-style request body into a ``ChatMessageSerializer`` call, creating
    the ``Chat`` row on the fly when the request carries no ``chat_id``.
    """

    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))

    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))

    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))

    @staticmethod
    def get_message(instance):
        """Return the ``content`` of the last entry of ``instance['messages']``.

        :raises ChatException: when ``messages`` is missing or empty. This is reachable
            when ``chat`` is called with ``with_valid=False`` and would otherwise be an
            IndexError/TypeError.
        """
        messages = instance.get('messages')
        if not messages:
            raise ChatException(500, "消息不能为空")
        return messages[-1].get('content')

    @staticmethod
    def generate_chat(chat_id, application_id, message, client_id):
        """Ensure a ``Chat`` row exists for ``chat_id`` and return the id.

        When ``chat_id`` is None a new uuid1 string is generated. A missing chat is created
        with the first 1024 characters of ``message`` as its abstract.
        """
        if chat_id is None:
            chat_id = str(uuid.uuid1())
            # A freshly generated uuid1 cannot already be stored; skip the lookup.
            chat_exists = False
        else:
            chat_exists = QuerySet(Chat).filter(id=chat_id).exists()
        if not chat_exists:
            # Tolerate a None message instead of failing on the slice.
            abstract = (message or '')[0:1024]
            Chat(id=chat_id, application_id=application_id, abstract=abstract, client_id=client_id).save()
        return chat_id

    def chat(self, instance: Dict, with_valid=True):
        """Run an OpenAI-style chat request and return the response.

        :param instance: request body (``messages``, optional ``chat_id``, ``re_chat``,
            ``stream`` and ``form_data``)
        :param with_valid: validate this serializer and ``instance`` first
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            OpenAIInstanceSerializer(data=instance).is_valid(raise_exception=True)
        chat_id = instance.get('chat_id')
        message = self.get_message(instance)
        re_chat = instance.get('re_chat', False)
        stream = instance.get('stream', False)
        application_id = self.data.get('application_id')
        client_id = self.data.get('client_id')
        client_type = self.data.get('client_type')
        chat_id = self.generate_chat(chat_id, application_id, message, client_id)
        return ChatMessageSerializer(
            data={'chat_id': chat_id, 'message': message,
                  're_chat': re_chat,
                  'stream': stream,
                  'application_id': application_id,
                  'client_id': client_id,
                  'client_type': client_type, 'form_data': instance.get('form_data', {})}).chat(
            base_to_response=OpenaiToResponse())
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class ChatMessageSerializer(serializers.Serializer):
    """Runs one user message through an application's chat.

    Simple applications use the retrieval pipeline (``chat_simple``). Every other
    application type runs its latest published workflow (``chat_work_flow``).
    """
    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
    message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
    # Boolean fields use boolean error messages (they previously used the char variant).
    stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("是否流式回答"))
    re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("是否重新回答"))
    chat_record_id = serializers.UUIDField(required=False, allow_null=True,
                                           error_messages=ErrMessage.uuid("对话记录id"))
    runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                            error_messages=ErrMessage.char("节点id"))
    node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数"))
    application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
    form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
    image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))

    def is_valid_application_workflow(self, *, raise_exception=False):
        """Pre-flight checks for workflow applications (currently only the daily access quota)."""
        self.is_valid_intraday_access_num()

    def is_valid_chat_id(self, chat_info: ChatInfo):
        """Reject the request when it names an application other than the one owning the chat.

        :raises ChatException: when ``application_id`` is given and does not match.
        """
        application_id = self.data.get('application_id')
        if application_id is not None and application_id != str(chat_info.application.id):
            raise ChatException(500, "会话不存在")

    def is_valid_intraday_access_num(self):
        """Enforce the daily access quota for public access-token clients.

        The client's ``ApplicationPublicAccessClient`` row is created (with zero counts) on
        first use. The check compares the application's ``access_num`` against the client's
        ``intraday_access_num``.

        :raises AppChatNumOutOfBoundsFailed: when today's quota is used up.
        """
        if self.data.get('client_type') != AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
            return
        client_id = self.data.get('client_id')
        application_id = self.data.get('application_id')
        access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first()
        if access_client is None:
            access_client = ApplicationPublicAccessClient(id=client_id,
                                                          application_id=application_id,
                                                          access_num=0,
                                                          intraday_access_num=0)
            access_client.save()

        application_access_token = QuerySet(ApplicationAccessToken).filter(
            application_id=application_id).first()
        if application_access_token.access_num <= access_client.intraday_access_num:
            raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")

    def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False):
        """Pre-flight checks for simple applications.

        Checks the daily quota, then the status of the application's model. Returns
        ``chat_info`` unchanged when there is no model or the model row is missing.

        :raises ChatException: when the model is in ERROR or DOWNLOAD status.
        """
        self.is_valid_intraday_access_num()
        model = chat_info.application.model
        if model is None:
            return chat_info
        model = QuerySet(Model).filter(id=model.id).first()
        if model is None:
            return chat_info
        if model.status == Status.ERROR:
            raise ChatException(500, "当前模型不可用")
        if model.status == Status.DOWNLOAD:
            raise ChatException(500, "模型正在下载中,请稍后再发起对话")
        return chat_info

    def chat_simple(self, chat_info: ChatInfo, base_to_response):
        """Answer the message with the retrieval pipeline of a simple application.

        :return: the pipeline's ``chat_result`` context entry
        """
        message = self.data.get('message')
        re_chat = self.data.get('re_chat')
        stream = self.data.get('stream')
        client_id = self.data.get('client_id')
        client_type = self.data.get('client_type')
        pipeline_manage_builder = PipelineManage.builder()
        # Add the question-rewriting step only when problem optimization is enabled
        if chat_info.application.problem_optimization:
            pipeline_manage_builder.append_step(BaseResetProblemStep)
        # Build the pipeline: dataset search, human-message generation, chat
        pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
                            .append_step(BaseGenerateHumanMessageStep)
                            .append_step(BaseChatStep)
                            .add_base_to_response(base_to_response)
                            .build())
        exclude_paragraph_id_list = []
        # On a re-chat of the same question, exclude paragraphs already retrieved for it
        if re_chat:
            paragraph_id_list = flat_map(
                [[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
                 chat_record in chat_info.chat_record_list if
                 chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
                 chat_record.details['search_step']])
            exclude_paragraph_id_list = list(set(paragraph_id_list))
        # Build the run parameters
        params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
                                                     client_id, client_type, stream)
        # Run the pipeline
        pipeline_message.run(params)
        return pipeline_message.context['chat_result']

    @staticmethod
    def get_chat_record(chat_info, chat_record_id):
        """Find chat record ``chat_record_id``.

        With a ``chat_info``, the in-memory history is searched first (last match wins),
        then the database restricted to ``chat_info``'s chat. Without one, the record is
        looked up by id alone.

        :raises ChatException: when ``chat_info`` is given and its chat has no such record.
        :return: the record, or None when ``chat_info`` is None and no record has that id.
        """
        if chat_info is None:
            return QuerySet(ChatRecord).filter(id=chat_record_id).first()
        cached_matches = [chat_record for chat_record in chat_info.chat_record_list if
                          str(chat_record.id) == str(chat_record_id)]
        if cached_matches:
            return cached_matches[-1]
        chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first()
        if chat_record is None:
            raise ChatException(500, "对话纪要不存在")
        # The scoped query already loaded the row; re-fetching it by id alone was redundant.
        return chat_record

    def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
        """Answer the message by running the chat's published workflow.

        When ``chat_record_id`` is given, the run continues that record (with the optional
        ``runtime_node_id``/``node_data``). Otherwise a new uuid1 record id is used.
        """
        message = self.data.get('message')
        re_chat = self.data.get('re_chat')
        stream = self.data.get('stream')
        client_id = self.data.get('client_id')
        client_type = self.data.get('client_type')
        form_data = self.data.get('form_data')
        image_list = self.data.get('image_list')
        user_id = chat_info.application.user_id
        chat_record_id = self.data.get('chat_record_id')
        chat_record = None
        if chat_record_id is not None:
            chat_record = self.get_chat_record(chat_info, chat_record_id)
        work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
                                          {'history_chat_record': chat_info.chat_record_list, 'question': message,
                                           'chat_id': chat_info.chat_id, 'chat_record_id': str(
                                              uuid.uuid1()) if chat_record is None else chat_record.id,
                                           'stream': stream,
                                           're_chat': re_chat,
                                           'client_id': client_id,
                                           'client_type': client_type,
                                           'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
                                          base_to_response, form_data, image_list, self.data.get('runtime_node_id'),
                                          self.data.get('node_data'), chat_record)
        r = work_flow_manage.run()
        return r

    def chat(self, base_to_response: BaseToResponse = None):
        """Validate the request and answer the message.

        :param base_to_response: response adapter; a new ``SystemToResponse`` is created per
            call when omitted (a default-argument instance would be shared by every call).
        """
        if base_to_response is None:
            base_to_response = SystemToResponse()
        super().is_valid(raise_exception=True)
        chat_info = self.get_chat_info()
        self.is_valid_chat_id(chat_info)
        if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
            self.is_valid_application_simple(raise_exception=True, chat_info=chat_info)
            return self.chat_simple(chat_info, base_to_response)
        else:
            self.is_valid_application_workflow(raise_exception=True)
            return self.chat_work_flow(chat_info, base_to_response)

    def get_chat_info(self):
        """Return the cached ``ChatInfo`` for ``chat_id``.

        On a cache miss, it is rebuilt from the database and cached for 30 minutes.
        """
        self.is_valid(raise_exception=True)
        chat_id = self.data.get('chat_id')
        chat_info: ChatInfo = chat_cache.get(chat_id)
        if chat_info is None:
            chat_info: ChatInfo = self.re_open_chat(chat_id)
            chat_cache.set(chat_id,
                           chat_info, timeout=60 * 30)
        return chat_info

    def re_open_chat(self, chat_id: str):
        """Rebuild a ``ChatInfo`` for an existing chat from the database.

        :raises ChatException: when the chat or its application does not exist.
        """
        chat = QuerySet(Chat).filter(id=chat_id).first()
        if chat is None:
            raise ChatException(500, "会话不存在")
        application = QuerySet(Application).filter(id=chat.application_id).first()
        if application is None:
            raise ChatException(500, "应用不存在")
        if application.type == ApplicationTypeChoices.SIMPLE:
            return self.re_open_chat_simple(chat_id, application)
        else:
            return self.re_open_chat_work_flow(chat_id, application)

    @staticmethod
    def _append_recent_chat_records(chat_info: ChatInfo, chat_id):
        """Load the 5 most recent records of ``chat_id`` into ``chat_info``, oldest first."""
        recent_records = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
        recent_records.sort(key=lambda r: r.create_time)
        chat_info.chat_record_list.extend(recent_records)
        return chat_info

    @staticmethod
    def re_open_chat_simple(chat_id, application):
        """Rebuild a ``ChatInfo`` for a simple application.

        The datasets are those mapped to the application; inactive documents
        (``is_active=False``) in them are excluded from search.
        """
        # Datasets linked to the application
        dataset_id_list = [str(row.dataset_id) for row in
                           QuerySet(ApplicationDatasetMapping).filter(
                               application_id=application.id)]

        # Documents to exclude from the search
        exclude_document_id_list = [str(document.id) for document in
                                    QuerySet(Document).filter(
                                        dataset_id__in=dataset_id_list,
                                        is_active=False)]
        chat_info = ChatInfo(chat_id, dataset_id_list, exclude_document_id_list, application)
        return ChatMessageSerializer._append_recent_chat_records(chat_info, chat_id)

    @staticmethod
    def re_open_chat_work_flow(chat_id, application):
        """Rebuild a ``ChatInfo`` for a workflow application using its latest workflow version.

        :raises ChatException: when the application has no published workflow version.
        """
        work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
            '-create_time').first()
        if work_flow_version is None:
            raise ChatException(500, "应用未发布,请发布后再使用")

        chat_info = ChatInfo(chat_id, [], [], application, work_flow_version)
        return ChatMessageSerializer._append_recent_chat_records(chat_info, chat_id)
|