123 lines
5.7 KiB
Python
123 lines
5.7 KiB
Python
|
|
# coding=utf-8
|
|||
|
|
"""
|
|||
|
|
@project: maxkb
|
|||
|
|
@Author:虎
|
|||
|
|
@file: model.py
|
|||
|
|
@date:2023/11/2 13:55
|
|||
|
|
@desc:
|
|||
|
|
"""
|
|||
|
|
from drf_yasg.utils import swagger_auto_schema
|
|||
|
|
from rest_framework.decorators import action
|
|||
|
|
from rest_framework.views import APIView
|
|||
|
|
from rest_framework.views import Request
|
|||
|
|
|
|||
|
|
from common.auth import TokenAuth, has_permissions
|
|||
|
|
from common.constants.permission_constants import PermissionConstants
|
|||
|
|
from common.response import result
|
|||
|
|
from common.util.common import query_params_to_single_dict
|
|||
|
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
|||
|
|
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer
|
|||
|
|
from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Model(APIView):
    # Endpoints for creating a model and listing models.
    # NOTE: documentation is kept in `#` comments rather than docstrings because
    # DRF / drf-yasg pick up view and method docstrings as API descriptions.
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(
        operation_summary="创建模型",
        operation_id="创建模型",
        request_body=ModelCreateApi.get_request_body_api(),
        tags=["模型"],
    )
    @has_permissions(PermissionConstants.MODEL_CREATE)
    def post(self, request: Request):
        # Merge the request body with the caller's id (as a string), hand the
        # result to ModelSerializer.Create and wrap whatever insert() returns
        # in the standard success response.
        payload = {**request.data, 'user_id': str(request.user.id)}
        serializer = ModelSerializer.Create(data=payload)
        created = serializer.insert(request.user.id, with_valid=True)
        return result.success(created)

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(
        operation_summary="获取模型列表",
        operation_id="获取模型列表",
        manual_parameters=ModelQueryApi.get_request_params_api(),
        tags=["模型"],
    )
    @has_permissions(PermissionConstants.MODEL_READ)
    def get(self, request: Request):
        # Flatten the query string into a plain dict, add the caller's id
        # (passed through unchanged, not stringified) and return the result of
        # ModelSerializer.Query(...).list() in the standard success response.
        filters = {**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}
        query = ModelSerializer.Query(data=filters)
        models = query.list(with_valid=True)
        return result.success(models)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Provide(APIView):
    # Endpoints exposing model-provider metadata. The view itself lists all
    # providers; the nested views cover provider function calls, model types,
    # model names and the model-creation form.
    # NOTE: documentation is kept in `#` comments rather than docstrings because
    # DRF / drf-yasg pick up view and method docstrings as API descriptions.

    # This view serves a permission-guarded GET of its own, so it needs the same
    # token authentication that every other view in this module declares.
    authentication_classes = [TokenAuth]

    class Exec(APIView):
        # Invokes a named function on a provider; `provider` and `method` are
        # received as handler arguments (presumably URL path parameters --
        # confirm against the URL configuration).
        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(
            operation_summary="调用供应商函数,获取表单数据",
            operation_id="调用供应商函数,获取表单数据",
            manual_parameters=ProvideApi.get_request_params_api(),
            request_body=ProvideApi.get_request_body_api(),
            tags=["模型"],
        )
        @has_permissions(PermissionConstants.MODEL_READ)
        def post(self, request: Request, provider: str, method: str):
            # The request body is forwarded untouched to ProviderSerializer.exec().
            serializer = ProviderSerializer(data={'provider': provider, 'method': method})
            return result.success(serializer.exec(request.data, with_valid=True))

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(
        operation_summary="获取模型供应商数据",
        operation_id="获取模型供应商列表",
        tags=["模型"],
    )
    @has_permissions(PermissionConstants.MODEL_READ)
    def get(self, request: Request):
        # One dict per entry in ModelProvideConstants.__members__, built from
        # that entry's provider info.
        return result.success(
            [ModelProvideConstants[key].value.get_model_provide_info().to_dict()
             for key in ModelProvideConstants.__members__])

    class ModelTypeList(APIView):
        # Lists the model types offered by one provider.
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(
            operation_summary="获取模型类型列表",
            operation_id="获取模型类型类型列表",
            manual_parameters=ProvideApi.ModelTypeList.get_request_params_api(),
            responses=result.get_api_array_response(ProvideApi.ModelTypeList.get_response_body_api()),
            tags=["模型"],
        )
        @has_permissions(PermissionConstants.MODEL_READ)
        def get(self, request: Request):
            # `provider` comes from the query string and is used directly as the
            # lookup name; a missing or unknown value makes the subscript fail
            # (presumably with KeyError -- confirm ModelProvideConstants is an Enum).
            provider = request.query_params.get('provider')
            return result.success(ModelProvideConstants[provider].value.get_model_type_list())

    class ModelList(APIView):
        # Lists the models a provider offers for a given model type.
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(
            operation_summary="获取模型列表",
            # Was "获取模型创建表单", a copy of ModelForm's id below; operation ids
            # must be unique, and "获取模型列表" is already taken by Model.get.
            operation_id="获取供应商模型列表",
            manual_parameters=ProvideApi.ModelList.get_request_params_api(),
            responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api()),
            tags=["模型"],
        )
        @has_permissions(PermissionConstants.MODEL_READ)
        def get(self, request: Request):
            # Both values are read from the query string; `model_type` is passed
            # through as-is (None when absent).
            provider = request.query_params.get('provider')
            model_type = request.query_params.get('model_type')

            return result.success(
                ModelProvideConstants[provider].value.get_model_list(model_type))

    class ModelForm(APIView):
        # Returns the credential form definition for one provider / type / model.
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(
            operation_summary="获取模型创建表单",
            operation_id="获取模型创建表单",
            manual_parameters=ProvideApi.ModelForm.get_request_params_api(),
            tags=["模型"],
        )
        @has_permissions(PermissionConstants.MODEL_READ)
        def get(self, request: Request):
            # All three selectors are read from the query string.
            provider = request.query_params.get('provider')
            model_type = request.query_params.get('model_type')
            model_name = request.query_params.get('model_name')
            credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_name)
            return result.success(credential.to_form_list())
|