UnisKB/apps/application/serializers/common.py

144 lines
7.2 KiB
Python
Raw Normal View History

2025-06-09 08:18:43 +00:00
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file common.py
@date2025/6/9 13:42
@desc:
"""
from datetime import datetime
from typing import List
from django.core.cache import cache
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
from application.models import Application, WorkFlowVersion, ChatRecord, Chat
from common.constants.cache_version import Cache_Version
from models_provider.models import Model
from models_provider.tools import get_model_credential
class ChatInfo:
    """State of a single chat: identifiers, application config and records.

    An instance collects the chat records appended through
    ``append_chat_record`` and builds the parameter dictionaries returned by
    ``to_base_pipeline_manage_params`` / ``to_pipeline_manage_params``. It is
    stored in and read back from the Django cache via ``set_cache`` and
    ``get_cache``.
    """

    def __init__(self,
                 chat_id: str,
                 chat_user_id: str,
                 chat_user_type: str,
                 knowledge_id_list: List[str],
                 exclude_document_id_list: List[str],
                 application_id: str,
                 application: Application,
                 work_flow_version: WorkFlowVersion = None,
                 debug=False):
        """
        :param chat_id: chat id
        :param chat_user_id: id of the chat user
        :param chat_user_type: type of the chat user
        :param knowledge_id_list: list of knowledge base ids
        :param exclude_document_id_list: ids of the documents to exclude
        :param application_id: application id
        :param application: application information
        :param work_flow_version: workflow version, if any
        :param debug: whether this is a debug chat (nothing is written to the
                      database by ``append_chat_record`` when true)
        """
        self.chat_id = chat_id
        self.chat_user_id = chat_user_id
        self.chat_user_type = chat_user_type
        self.application = application
        self.knowledge_id_list = knowledge_id_list
        self.exclude_document_id_list = exclude_document_id_list
        self.application_id = application_id
        self.chat_record_list: List[ChatRecord] = []
        self.work_flow_version = work_flow_version
        self.debug = debug

    @staticmethod
    def get_no_references_setting(knowledge_setting, model_setting):
        """Resolve the ``no_references_setting`` entry of the knowledge setting.

        Falls back to ``{'status': 'ai_questioning', 'value': '{question}'}``
        when the entry is missing. When the status is ``'ai_questioning'`` the
        ``value`` is taken from ``model_setting['no_references_prompt']``, or
        ``'{question}'`` if that prompt is missing or empty.

        :param knowledge_setting: knowledge setting dict of the application
        :param model_setting: model setting dict of the application
        :return: a new dict; the dict stored in ``knowledge_setting`` is left
                 untouched
        """
        configured = knowledge_setting.get('no_references_setting')
        if configured is None:
            no_references_setting = {'status': 'ai_questioning', 'value': '{question}'}
        else:
            # Copy so that filling in 'value' below does not rewrite the
            # setting stored on the (cached, shared) application object.
            no_references_setting = dict(configured)
        if no_references_setting.get('status') == 'ai_questioning':
            # `or` covers a missing key, an explicit None and an empty string.
            no_references_setting['value'] = model_setting.get('no_references_prompt') or '{question}'
        return no_references_setting

    def to_base_pipeline_manage_params(self):
        """Build the parameters that do not depend on a particular question.

        :return: dict of pipeline parameters derived from the application's
                 knowledge and model settings and from this chat's state
        """
        application = self.application
        knowledge_setting = application.knowledge_setting
        model_setting = application.model_setting
        model_id = application.model.id if application.model is not None else None

        # Default parameter form data of the model, used only when the
        # application has no parameter setting of its own.
        default_model_params_setting = None
        if model_id is not None:
            model = QuerySet(Model).filter(id=model_id).first()
            # The row may no longer exist; keep the default as None in that
            # case instead of failing on attribute access.
            if model is not None:
                credential = get_model_credential(model.provider, model.model_type, model.model_name)
                default_model_params_setting = credential.get_model_params_setting_form(
                    model.model_name).get_default_form_data()

        # A similarity of 0 is a legitimate threshold, so only a missing/None
        # value falls back to the default.
        similarity = knowledge_setting.get('similarity')
        if similarity is None:
            similarity = 0.6

        return {
            'knowledge_id_list': self.knowledge_id_list,
            'exclude_document_id_list': self.exclude_document_id_list,
            'exclude_paragraph_id_list': [],
            'top_n': knowledge_setting.get('top_n') or 3,
            'similarity': similarity,
            'max_paragraph_char_number': knowledge_setting.get('max_paragraph_char_number') or 5000,
            'history_chat_record': self.chat_record_list,
            'chat_id': self.chat_id,
            'dialogue_number': application.dialogue_number,
            'problem_optimization_prompt': application.problem_optimization_prompt or _(
                "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag"),
            # Missing, None or empty prompt -> the application's default prompt.
            'prompt': model_setting.get('prompt') or Application.get_default_model_prompt(),
            'system': model_setting.get('system', None),
            'model_id': model_id,
            'problem_optimization': application.problem_optimization,
            'stream': True,
            'model_setting': model_setting,
            # The application's own setting wins unless it is None or empty.
            'model_params_setting': application.model_params_setting or default_model_params_setting,
            'search_mode': knowledge_setting.get('search_mode') or 'embedding',
            'no_references_setting': self.get_no_references_setting(knowledge_setting, model_setting),
            'user_id': application.user_id,
            'application_id': application.id
        }

    def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
                                  exclude_paragraph_id_list, chat_user_id: str, chat_user_type, stream=True,
                                  form_data=None):
        """Build the full parameter dict for one question.

        Starts from ``to_base_pipeline_manage_params()`` and adds/overrides
        the per-question entries.

        :param problem_text: the question text
        :param post_response_handler: handler stored under 'post_response_handler'
        :param exclude_paragraph_id_list: paragraph ids to exclude
        :param chat_user_id: id of the chat user
        :param chat_user_type: type of the chat user
        :param stream: value stored under 'stream'
        :param form_data: form data; ``None`` is replaced by an empty dict
        :return: dict of pipeline parameters
        """
        if form_data is None:
            form_data = {}
        params = self.to_base_pipeline_manage_params()
        return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
                'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream,
                'chat_user_id': chat_user_id, 'chat_user_type': chat_user_type, 'form_data': form_data}

    def append_chat_record(self, chat_record: ChatRecord):
        """Add or replace a chat record in this chat and persist it.

        The question is truncated to 10240 characters and the answer to 40960
        characters (``None`` becomes ``""``). A record with the same id that
        is already in ``chat_record_list`` is replaced; otherwise the record
        is appended. Unless this is a debug chat, the ``Chat`` row is created
        (or its ``update_time`` refreshed) and the record is saved.

        :param chat_record: the record to store
        """
        problem_text = chat_record.problem_text
        answer_text = chat_record.answer_text
        chat_record.problem_text = problem_text[0:10240] if problem_text is not None else ""
        # Guard on the answer itself: a record without an answer must not
        # fail on slicing None.
        chat_record.answer_text = answer_text[0:40960] if answer_text is not None else ""

        # Keep the in-memory (cached) list in sync.
        for index, record in enumerate(self.chat_record_list):
            if record.id == chat_record.id:
                self.chat_record_list[index] = chat_record
                break
        else:
            self.chat_record_list.append(chat_record)

        if not self.debug:
            if not QuerySet(Chat).filter(id=self.chat_id).exists():
                Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
                     chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save()
            else:
                QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now())
            # Insert the chat record.
            chat_record.save()

    def set_cache(self):
        """Store this instance in the cache under its chat id for 30 minutes."""
        cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
                  timeout=60 * 30)

    @staticmethod
    def get_cache(chat_id):
        """Return the cached value for ``chat_id`` (``cache.get`` result)."""
        return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())