# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: chat_views.py
    @date:2023/11/14 9:53
    @desc: Chat API views (open chat, send message, chat logs, chat records, vote, improve)
"""
from drf_yasg.utils import swagger_auto_schema
|
|||
|
|
from rest_framework.decorators import action
|
|||
|
|
from rest_framework.request import Request
|
|||
|
|
from rest_framework.views import APIView
|
|||
|
|
|
|||
|
|
from application.serializers.chat_message_serializers import ChatMessageSerializer
|
|||
|
|
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
|
|||
|
|
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi
|
|||
|
|
from common.auth import TokenAuth, has_permissions
|
|||
|
|
from common.constants.permission_constants import Permission, Group, Operate, \
|
|||
|
|
RoleConstants, ViewPermission, CompareConstants
|
|||
|
|
from common.exception.app_exception import AppAuthenticationFailed
|
|||
|
|
from common.response import result
|
|||
|
|
from common.util.common import query_params_to_single_dict
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatView(APIView):
    """Chat-related API views of an application.

    Groups the nested ``APIView`` classes (open a chat, open a temporary chat,
    send a message, paged chat list, chat records) and itself serves the
    un-paged chat list through :meth:`get`.
    """
    authentication_classes = [TokenAuth]

    class Open(APIView):
        """Open a chat on an existing application."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取会话id,根据应用id",
                             operation_id="获取会话id,根据应用id",
                             manual_parameters=ChatApi.OpenChat.get_request_params_api(),
                             tags=["应用/会话"])
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
                            RoleConstants.APPLICATION_KEY],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))],
                           compare=CompareConstants.AND)
        )
        def get(self, request: Request, application_id: str):
            """Open a chat on ``application_id`` for the requesting user.

            :param request: authenticated request; ``request.user.id`` is passed on as ``user_id``
            :param application_id: application identifier taken from the URL
            :return: ``result.success`` wrapping the value of ``ChatSerializers.OpenChat(...).open()``
                     (a chat id, according to the Swagger summary)
            """
            return result.success(ChatSerializers.OpenChat(
                data={'user_id': request.user.id, 'application_id': application_id}).open())

    class OpenTemp(APIView):
        """Open a temporary chat described entirely by the request body."""
        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="获取会话id(根据模型id,数据集列表,是否多轮会话)",
                             operation_id="获取会话id",
                             request_body=ChatApi.OpenTempChat.get_request_body_api(),
                             tags=["应用/会话"])
        @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
        def post(self, request: Request):
            """Open a temporary chat from the posted settings.

            The request body is forwarded as-is; ``user_id`` is always overwritten
            with the authenticated user's id (it is merged last).

            :param request: authenticated request carrying the chat settings in ``request.data``
            :return: ``result.success`` wrapping the value of ``ChatSerializers.OpenTempChat(...).open()``
            """
            return result.success(ChatSerializers.OpenTempChat(
                data={**request.data, 'user_id': request.user.id}).open())

    class Message(APIView):
        """Send a message within an opened chat."""
        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="对话",
                             operation_id="对话",
                             request_body=ChatApi.get_request_body_api(),
                             tags=["应用/会话"])
        # NOTE(review): the permission reads 'application_id' from ``keywords`` but this
        # handler only receives ``chat_id`` - presumably ``keywords`` are the URL kwargs,
        # in which case ``dynamic_tag`` would be None here. Confirm against has_permissions.
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
                            RoleConstants.APPLICATION_ACCESS_TOKEN],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))])
        )
        def post(self, request: Request, chat_id: str):
            """Send the posted ``message`` to chat ``chat_id``.

            :param request: request whose body may contain ``message`` (``None`` is passed when absent)
            :param chat_id: chat identifier taken from the URL
            :return: whatever ``ChatMessageSerializer(...).chat(...)`` returns; unlike the other
                     handlers it is returned directly, not wrapped in ``result.success``
            """
            return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取对话列表",
                         operation_id="获取对话列表",
                         manual_parameters=ChatApi.get_request_params_api(),
                         tags=["应用/对话日志"]
                         )
    @has_permissions(
        ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                       [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                       dynamic_tag=keywords.get('application_id'))])
    )
    def get(self, request: Request, application_id: str):
        """List the chats of ``application_id``.

        Query-string parameters are flattened with ``query_params_to_single_dict`` and
        used as filters; ``application_id`` and ``user_id`` are merged afterwards so
        they cannot be overridden from the query string.

        :param request: authenticated request carrying optional filters in ``query_params``
        :param application_id: application identifier taken from the URL
        :return: ``result.success`` wrapping ``ChatSerializers.Query(...).list()``
        """
        return result.success(ChatSerializers.Query(
            data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
                  'user_id': request.user.id}).list())

    class Page(APIView):
        """Paged variant of the chat list."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="分页获取对话列表",
                             operation_id="分页获取对话列表",
                             manual_parameters=result.get_page_request_params(ChatApi.get_request_params_api()),
                             tags=["应用/对话日志"]
                             )
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))])
        )
        def get(self, request: Request, application_id: str, current_page: int, page_size: int):
            """Return one page of the chats of ``application_id``.

            :param request: authenticated request carrying optional filters in ``query_params``
            :param application_id: application identifier taken from the URL
            :param current_page: page number, forwarded unchanged to the serializer
            :param page_size: page size, forwarded unchanged to the serializer
            :return: ``result.success`` wrapping ``ChatSerializers.Query(...).page(...)``
            """
            return result.success(ChatSerializers.Query(
                data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
                      'user_id': request.user.id}).page(current_page=current_page,
                                                        page_size=page_size))

    class ChatRecord(APIView):
        """Records of a single chat: listing, paging, voting and improving."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取对话记录列表",
                             operation_id="获取对话记录列表",
                             manual_parameters=ChatRecordApi.get_request_params_api(),
                             tags=["应用/对话日志"]
                             )
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))])
        )
        def get(self, request: Request, application_id: str, chat_id: str):
            """List the records of chat ``chat_id``.

            :param request: authenticated request (not read by the handler body)
            :param application_id: application identifier taken from the URL
            :param chat_id: chat identifier taken from the URL
            :return: ``result.success`` wrapping ``ChatRecordSerializer.Query(...).list()``
            """
            return result.success(ChatRecordSerializer.Query(
                data={'application_id': application_id,
                      'chat_id': chat_id}).list())

        class Page(APIView):
            """Paged variant of the chat record list."""
            # Declared explicitly, like every sibling view, so this endpoint authenticates
            # with TokenAuth instead of falling back to the project-wide default
            # authentication classes.
            authentication_classes = [TokenAuth]

            @action(methods=['GET'], detail=False)
            # Summary/id say "paged" so the operation id does not collide with the
            # un-paged ChatRecord.get operation above.
            @swagger_auto_schema(operation_summary="分页获取对话记录列表",
                                 operation_id="分页获取对话记录列表",
                                 manual_parameters=result.get_page_request_params(
                                     ChatRecordApi.get_request_params_api()),
                                 tags=["应用/对话日志"]
                                 )
            @has_permissions(
                ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                               [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                               dynamic_tag=keywords.get('application_id'))])
            )
            def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int):
                """Return one page of the records of chat ``chat_id``.

                :param request: authenticated request (not read by the handler body)
                :param application_id: application identifier taken from the URL
                :param chat_id: chat identifier taken from the URL
                :param current_page: page number, forwarded unchanged to the serializer
                :param page_size: page size, forwarded unchanged to the serializer
                :return: ``result.success`` wrapping ``ChatRecordSerializer.Query(...).page(...)``
                """
                return result.success(ChatRecordSerializer.Query(
                    data={'application_id': application_id,
                          'chat_id': chat_id}).page(current_page, page_size))

        class Vote(APIView):
            """Up-vote / down-vote a chat record."""
            authentication_classes = [TokenAuth]

            @action(methods=['PUT'], detail=False)
            @swagger_auto_schema(operation_summary="点赞,点踩",
                                 operation_id="点赞,点踩",
                                 manual_parameters=VoteApi.get_request_params_api(),
                                 request_body=VoteApi.get_request_body_api(),
                                 responses=result.get_default_response(),
                                 tags=["应用/会话"]
                                 )
            @has_permissions(
                ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
                                RoleConstants.APPLICATION_ACCESS_TOKEN],
                               [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                               dynamic_tag=keywords.get('application_id'))])
            )
            def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
                """Store the posted ``vote_status`` for record ``chat_record_id``.

                :param request: request whose body may contain ``vote_status`` (``None`` when absent)
                :param application_id: application identifier from the URL; only the permission
                                       check uses it, the handler body does not
                :param chat_id: chat identifier taken from the URL
                :param chat_record_id: chat record identifier taken from the URL
                :return: ``result.success`` wrapping ``ChatRecordSerializer.Vote(...).vote()``
                """
                return result.success(ChatRecordSerializer.Vote(
                    data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id,
                          'chat_record_id': chat_record_id}).vote())

        class Improve(APIView):
            """Annotate ("improve") a chat record against a dataset document."""
            authentication_classes = [TokenAuth]

            @action(methods=['PUT'], detail=False)
            @swagger_auto_schema(operation_summary="标注",
                                 operation_id="标注",
                                 manual_parameters=ImproveApi.get_request_params_api(),
                                 request_body=ImproveApi.get_request_body_api(),
                                 responses=result.get_default_response(),
                                 tags=["应用/对话日志/标注"]
                                 )
            # NOTE(review): the second ViewPermission is passed as the third positional
            # argument of the first one rather than as a separate argument of
            # has_permissions, and it checks Group.APPLICATION against 'dataset_id'.
            # Both look unintended - confirm against the ViewPermission/has_permissions
            # signatures before changing; kept exactly as written.
            @has_permissions(
                ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
                               [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                               dynamic_tag=keywords.get('application_id'))],
                               ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
                                              [lambda r, keywords: Permission(group=Group.APPLICATION,
                                                                              operate=Operate.MANAGE,
                                                                              dynamic_tag=keywords.get(
                                                                                  'dataset_id'))],
                                              )
                               ))
            def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str, dataset_id: str,
                    document_id: str):
                """Apply the posted improvement to record ``chat_record_id``.

                :param request: request whose body is handed unchanged to ``improve``
                :param application_id: application identifier from the URL; only the permission
                                       check uses it, the handler body does not
                :param chat_id: chat identifier taken from the URL
                :param chat_record_id: chat record identifier taken from the URL
                :param dataset_id: dataset identifier taken from the URL
                :param document_id: document identifier taken from the URL
                :return: ``result.success`` wrapping ``ChatRecordSerializer.Improve(...).improve(request.data)``
                """
                return result.success(ChatRecordSerializer.Improve(
                    data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
                          'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data))
|