2023-11-16 05:16:27 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: maxkb
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: provider_serializers.py
|
|
|
|
|
|
@date:2023/11/2 14:01
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
|
|
|
|
|
import json
|
2024-07-16 08:08:31 +00:00
|
|
|
|
import re
|
2024-03-22 09:56:56 +00:00
|
|
|
|
import threading
|
|
|
|
|
|
import time
|
2023-11-16 05:16:27 +00:00
|
|
|
|
import uuid
|
|
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
2024-07-16 08:08:31 +00:00
|
|
|
|
from django.core import validators
|
|
|
|
|
|
from django.db.models import QuerySet, Q
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from rest_framework import serializers
|
|
|
|
|
|
|
2024-03-05 06:29:50 +00:00
|
|
|
|
from application.models import Application
|
2024-07-24 08:12:05 +00:00
|
|
|
|
from common.config.embedding_config import ModelManage
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from common.exception.app_exception import AppApiException
|
2024-03-04 02:12:18 +00:00
|
|
|
|
from common.util.field_message import ErrMessage
|
2024-04-29 05:28:47 +00:00
|
|
|
|
from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
|
2024-07-29 07:15:41 +00:00
|
|
|
|
from dataset.models import DataSet
|
2024-08-26 10:06:32 +00:00
|
|
|
|
from setting.models.model_management import Model, Status, PermissionType
|
2024-08-23 09:46:05 +00:00
|
|
|
|
from setting.models_provider import get_model, get_model_credential
|
2024-03-22 09:56:56 +00:00
|
|
|
|
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
2023-11-16 05:16:27 +00:00
|
|
|
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
2025-01-13 03:15:51 +00:00
|
|
|
|
from django.utils.translation import gettext_lazy as _
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
2024-12-10 06:56:52 +00:00
|
|
|
|
def get_default_model_params_setting(provider, model_type, model_name):
    """Return the default parameter-setting form of a model as a list.

    The credential class registered for ``provider`` / ``model_type`` /
    ``model_name`` is asked for its parameter-setting form; when it provides
    none, an empty list is returned instead of ``None``.
    """
    params_form = get_model_credential(provider, model_type, model_name).get_model_params_setting_form(model_name)
    return [] if params_form is None else params_form.to_form_list()
|
2024-12-10 06:56:52 +00:00
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
2024-03-22 09:56:56 +00:00
|
|
|
|
class ModelPullManage:
    """Downloads ("pulls") models through their provider and tracks progress on the Model row."""

    @staticmethod
    def pull(model: Model, credential: Dict):
        """Download ``model`` via its provider's ``down_model`` and record the outcome.

        Progress chunks are collected by digest and flushed to ``model.meta`` once more
        than 5 seconds have passed since the previous flush. Before each flush the row is
        re-read: if it has been deleted or its status was set to ``Status.PAUSE_DOWNLOAD``
        the download loop stops without touching the row.

        When the stream ends, the status becomes ``Status.SUCCESS`` if any chunk reported
        ``DownModelChunkStatus.success``, otherwise ``Status.ERROR``; the message is the
        digest of the last chunk that reported ``DownModelChunkStatus.error``. Any
        exception marks the model as ``Status.ERROR`` with the exception text as message.

        :param model: model row being downloaded
        :param credential: plaintext credential passed through to the provider
        """
        try:
            response = ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name,
                                                                              credential)
            # Keyed by digest: a later chunk with the same digest replaces the earlier one.
            down_model_chunk = {}
            timestamp = time.time()
            for chunk in response:
                down_model_chunk[chunk.digest] = chunk.to_dict()
                if time.time() - timestamp > 5:
                    model_new = QuerySet(Model).filter(id=model.id).first()
                    # Stop when the model was deleted meanwhile or the user paused the download.
                    if model_new is None or model_new.status == Status.PAUSE_DOWNLOAD:
                        return
                    QuerySet(Model).filter(id=model.id).update(
                        meta={"down_model_chunk": list(down_model_chunk.values())})
                    timestamp = time.time()
            status = Status.ERROR
            message = ""
            for chunk in down_model_chunk.values():
                if chunk.get('status') == DownModelChunkStatus.success.value:
                    status = Status.SUCCESS
                if chunk.get('status') == DownModelChunkStatus.error.value:
                    message = chunk.get("digest")
            QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": [], "message": message},
                                                       status=status)
        except Exception as e:
            QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": [], "message": str(e)},
                                                       status=Status.ERROR)
|
2024-03-22 09:56:56 +00:00
|
|
|
|
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
class ModelSerializer(serializers.Serializer):
    """Serializers implementing list/create/edit/delete operations on user models."""

    class Query(serializers.Serializer):
        """List the models visible to ``user_id``, optionally filtered."""
        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('user id')))

        name = serializers.CharField(required=False, max_length=64,
                                     error_messages=ErrMessage.char(_('model name')))

        model_type = serializers.CharField(required=False, error_messages=ErrMessage.char(_('model type')))

        model_name = serializers.CharField(required=False, error_messages=ErrMessage.char(_('model name')))

        provider = serializers.CharField(required=False, error_messages=ErrMessage.char(_('provider')))

        permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char(_('permission type')))

        create_user = serializers.CharField(required=False, error_messages=ErrMessage.char(_('create user')))

        def list(self, with_valid):
            """Return the visible models as dicts, newest first.

            Without ``create_user`` the result is the caller's own models plus every
            PUBLIC model. With ``create_user`` equal to the caller, all of the caller's
            models are returned; for another user only that user's PUBLIC models.
            ``name`` is matched case-insensitively as a substring; ``model_type``,
            ``model_name``, ``provider`` and ``permission_type`` must match exactly.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            user_id = self.data.get('user_id')
            name = self.data.get('name')
            create_user = self.data.get('create_user')
            if create_user is not None:
                if create_user == user_id:
                    # The current user sees all of their own models, public and private
                    model_query_set = QuerySet(Model).filter(Q(user_id=create_user))
                else:
                    # Another user's models: only the public ones are visible
                    model_query_set = QuerySet(Model).filter(
                        Q(user_id=create_user) & Q(permission_type='PUBLIC'))
            else:
                model_query_set = QuerySet(Model).filter(Q(user_id=user_id) | Q(permission_type='PUBLIC'))
            query_params = {}
            if name is not None:
                query_params['name__icontains'] = name
            for exact_key in ('model_type', 'model_name', 'provider', 'permission_type'):
                if self.data.get(exact_key) is not None:
                    query_params[exact_key] = self.data.get(exact_key)

            # select_related('user') fetches the owner in the same query instead of
            # issuing one extra query per row for model.user.username.
            models = model_query_set.filter(**query_params).select_related('user').order_by("-create_time")
            return [
                {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
                 'model_name': model.model_name, 'status': model.status, 'meta': model.meta,
                 'permission_type': model.permission_type, 'user_id': model.user_id, 'username': model.user.username}
                for model in models]

    class Edit(serializers.Serializer):
        """Validation of the fields a user may change on an existing model."""
        user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid(_('user id')))

        name = serializers.CharField(required=False, max_length=64,
                                     error_messages=ErrMessage.char(_("model name")))

        model_type = serializers.CharField(required=False, error_messages=ErrMessage.char(_("model type")))

        # The group is required: without it "^PUBLIC|PRIVATE$" means "starts with
        # PUBLIC or ends with PRIVATE" and accepts values such as "PUBLICxyz".
        permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char(_("permission type")),
                                                validators=[
                                                    validators.RegexValidator(regex=re.compile("^(PUBLIC|PRIVATE)$"),
                                                                              message=_(
                                                                                  "permissions only supportPUBLIC|PRIVATE"),
                                                                              code=500)
                                                ])

        model_name = serializers.CharField(required=False, error_messages=ErrMessage.char(_("model name")))

        credential = serializers.DictField(required=False,
                                           error_messages=ErrMessage.dict(_("certification information")))

        def is_valid(self, model=None, raise_exception=False):
            """Validate an edit of ``model`` and resolve the credential to store.

            Always raises on invalid input (``raise_exception`` is ignored). Raises
            AppApiException when another model of the same user already uses the new name.

            :param model: the Model row being edited (required despite the default)
            :return: ``(credential, model_credential, provider_handler)``; ``credential`` is
                the submitted dict with every value that still equals the provider's
                ``encryption_dict`` form of the stored credential replaced by the stored
                plaintext value (or ``None`` when no credential was submitted).
            """
            super().is_valid(raise_exception=True)
            filter_params = {'user_id': self.data.get('user_id')}
            if 'name' in self.data and self.data.get('name') is not None:
                filter_params['name'] = self.data.get('name')
                if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists():
                    raise AppApiException(500, _('Model name【{model_name}】already exists').format(
                        model_name=self.data.get("name")))

            # NOTE(review): the result is discarded; the call only fails fast when the stored
            # credential cannot be decrypted / re-encrypted for the model's current type.
            ModelSerializer.model_to_dict(model)

            provider = model.provider
            model_type = self.data.get('model_type')
            model_name = self.data.get('model_name')
            credential = self.data.get('credential')
            provider_handler = ModelProvideConstants[provider].value
            model_credential = provider_handler.get_model_credential(model_type, model_name)
            source_model_credential = json.loads(rsa_long_decrypt(model.credential))
            source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
            if credential is not None:
                # model_to_dict hands clients the encryption_dict form of the credential;
                # a value sent back unchanged means "keep the stored value".
                for key, encrypted_value in source_encryption_model_credential.items():
                    if key in credential and credential[key] == encrypted_value:
                        credential[key] = source_model_credential[key]
            return credential, model_credential, provider_handler

    class Create(serializers.Serializer):
        """Validation and insertion of a new model."""
        user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_("user id")))

        name = serializers.CharField(required=True, max_length=64, error_messages=ErrMessage.char(_("model name")))

        provider = serializers.CharField(required=True, error_messages=ErrMessage.char(_("provider")))

        model_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("model type")))

        # Grouped alternation: accept exactly "PUBLIC" or "PRIVATE".
        permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("permission type")),
                                                validators=[
                                                    validators.RegexValidator(regex=re.compile("^(PUBLIC|PRIVATE)$"),
                                                                              message=_(
                                                                                  "permissions only supportPUBLIC|PRIVATE"),
                                                                              code=500)
                                                ])

        model_name = serializers.CharField(required=True, error_messages=ErrMessage.char(_("model name")))

        model_params_form = serializers.ListField(required=False, default=list,
                                                  error_messages=ErrMessage.char(_("parameter configuration")))

        credential = serializers.DictField(required=True,
                                           error_messages=ErrMessage.dict(_("certification information")))

        def is_valid(self, *, raise_exception=False):
            """Validate fields, name uniqueness per user, and the credential with the provider.

            Always raises on invalid input (``raise_exception`` is ignored). The provider's
            ``is_valid_credential`` receives the ``default_value`` of every entry of
            ``model_params_form`` keyed by its ``field``.
            """
            super().is_valid(raise_exception=True)
            if QuerySet(Model).filter(user_id=self.data.get('user_id'),
                                      name=self.data.get('name')).exists():
                raise AppApiException(500, _('Model name【{model_name}】already exists').format(
                    model_name=self.data.get("name")))
            default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
            ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'),
                                                                                      self.data.get('model_name'),
                                                                                      self.data.get('credential'),
                                                                                      default_params,
                                                                                      raise_exception=True
                                                                                      )

        def insert(self, user_id, with_valid=False):
            """Create the model row and return it as ``ModelSerializer.model_to_dict`` output.

            If validation fails with ``ValidCode.model_not_fount`` the row is still created
            with ``Status.DOWNLOAD`` and a background thread runs ``ModelPullManage.pull``;
            other validation errors are re-raised. The credential is stored RSA-encrypted.
            """
            status = Status.SUCCESS
            if with_valid:
                try:
                    self.is_valid(raise_exception=True)
                except AppApiException as e:
                    if e.code == ValidCode.model_not_fount:
                        status = Status.DOWNLOAD
                    else:
                        raise e
            credential = self.data.get('credential')
            name = self.data.get('name')
            provider = self.data.get('provider')
            model_type = self.data.get('model_type')
            model_name = self.data.get('model_name')
            permission_type = self.data.get('permission_type')
            model_params_form = self.data.get('model_params_form')
            model_credential_str = json.dumps(credential)
            model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
                          credential=rsa_long_encrypt(model_credential_str),
                          provider=provider, model_type=model_type, model_name=model_name,
                          model_params_form=model_params_form,
                          permission_type=permission_type)
            model.save()
            if status == Status.DOWNLOAD:
                thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
                thread.start()
            return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)

    @staticmethod
    def model_to_dict(model: Model):
        """Serialize ``model``; the credential is returned in its provider ``encryption_dict`` form."""
        credential = json.loads(rsa_long_decrypt(model.credential))
        return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
                'model_name': model.model_name,
                'status': model.status,
                'meta': model.meta,
                'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
                                                                                               model.model_name).encryption_dict(
                    credential),
                'permission_type': model.permission_type}

    class ModelParams(serializers.Serializer):
        """Read access to a model's saved parameter form."""
        id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))

        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("user id")))

        def is_valid(self, *, raise_exception=False):
            """Raise unless the model exists and is PUBLIC or owned by ``user_id``."""
            super().is_valid(raise_exception=True)
            model = QuerySet(Model).filter(id=self.data.get("id")).first()
            if model is None:
                raise AppApiException(500, '模型不存在')
            if model.permission_type == PermissionType.PRIVATE and self.data.get('user_id') != str(model.user_id):
                raise AppApiException(500, '没有权限访问到此模型')

        def get_model_params(self, with_valid=True):
            """Return the model's stored ``model_params_form``."""
            if with_valid:
                self.is_valid(raise_exception=True)
            model_id = self.data.get('id')
            model = QuerySet(Model).filter(id=model_id).first()
            # The parameter form previously saved for this model
            return model.model_params_form

    class ModelParamsForm(serializers.Serializer):
        """Write access to a model's saved parameter form."""
        id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))

        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("user id")))

        def is_valid(self, *, raise_exception=False):
            """Raise unless the model exists and is PUBLIC or owned by ``user_id``."""
            super().is_valid(raise_exception=True)
            model = QuerySet(Model).filter(id=self.data.get("id")).first()
            if model is None:
                raise AppApiException(500, '模型不存在')
            if model.permission_type == PermissionType.PRIVATE and self.data.get('user_id') != str(model.user_id):
                raise AppApiException(500, '没有权限访问到此模型')

        def save_model_params_form(self, model_params_form, with_valid=True):
            """Replace the model's ``model_params_form`` (``None`` is stored as ``[]``)."""
            # NOTE(review): is_valid only checks *read* access, so any user can overwrite the
            # form of someone else's PUBLIC model; confirm whether ownership should be required.
            if with_valid:
                self.is_valid(raise_exception=True)
            if model_params_form is None:
                model_params_form = []
            model_id = self.data.get('id')
            model = QuerySet(Model).filter(id=model_id).first()
            model.model_params_form = model_params_form
            model.save()
            return True

    class Operate(serializers.Serializer):
        """Operations on a single model owned by ``user_id``."""
        id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))

        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("user id")))

        def is_valid(self, *, raise_exception=False):
            """Raise unless a model with ``id`` owned by ``user_id`` exists."""
            super().is_valid(raise_exception=True)
            model = QuerySet(Model).filter(id=self.data.get("id"), user_id=self.data.get("user_id")).first()
            if model is None:
                raise AppApiException(500, '模型不存在')

        def one(self, with_valid=False):
            """Return the model as ``ModelSerializer.model_to_dict`` output."""
            if with_valid:
                self.is_valid(raise_exception=True)
            model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
            return ModelSerializer.model_to_dict(model)

        def one_meta(self, with_valid=False):
            """Return the model's basic fields and download ``meta`` (no credential)."""
            if with_valid:
                self.is_valid(raise_exception=True)
            model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
            return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
                    'model_name': model.model_name,
                    'status': model.status,
                    'meta': model.meta
                    }

        def delete(self, with_valid=True):
            """Delete the model unless an application or knowledge base still references it.

            LLM, TTS and STT models are checked against applications, EMBEDDING models
            against knowledge bases (DataSet). After deletion the model's ModelManage cache
            entry is evicted. Raises AppApiException if the model is missing or in use.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            model_id = self.data.get('id')
            model = Model.objects.filter(id=model_id).first()
            if not model:
                # The model does not exist
                raise AppApiException(500, "模型不存在")
            # Application field that references a model of each type
            application_field = {'LLM': 'model_id', 'TTS': 'tts_model_id', 'STT': 'stt_model_id'}.get(
                model.model_type)
            if application_field is not None:
                application_count = Application.objects.filter(**{application_field: model_id}).count()
                if application_count > 0:
                    raise AppApiException(500, f"该模型关联了{application_count} 个应用,无法删除该模型。")
            elif model.model_type == 'EMBEDDING':
                dataset_count = DataSet.objects.filter(embedding_mode_id=model_id).count()
                if dataset_count > 0:
                    raise AppApiException(500, f"该模型关联了{dataset_count} 个知识库,无法删除该模型。")
            # Capture the key first: Django clears the instance's primary key on delete().
            cache_key = str(model.id)
            model.delete()
            # Drop the cached instance so a deleted model is not served from the cache.
            ModelManage.delete_key(cache_key)
            return True

        def pause_download(self, with_valid=True):
            """Mark the model as ``Status.PAUSE_DOWNLOAD``; a running pull stops at its next check."""
            if with_valid:
                self.is_valid(raise_exception=True)
            QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
            return True

        def edit(self, instance: Dict, user_id: str, with_valid=True):
            """Apply the non-``None`` fields of ``instance`` to the model and return it.

            The credential is resolved and validated through ``ModelSerializer.Edit`` and the
            provider. If the provider reports ``ValidCode.model_not_fount`` the status becomes
            ``Status.DOWNLOAD`` and a background pull is started; other validation errors
            propagate. Raises AppApiException when the model id does not exist.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            model = QuerySet(Model).filter(id=self.data.get('id')).first()
            if model is None:
                raise AppApiException(500, '不存在的id')

            credential, model_credential, provider_handler = ModelSerializer.Edit(
                data={**instance, 'user_id': user_id}).is_valid(model=model)
            try:
                model.status = Status.SUCCESS
                default_params = {item['field']: item['default_value'] for item in model.model_params_form}
                # Validate the model credential with the provider
                provider_handler.is_valid_credential(model.model_type,
                                                     instance.get("model_name"),
                                                     credential,
                                                     default_params,
                                                     raise_exception=True)
            except AppApiException as e:
                if e.code == ValidCode.model_not_fount:
                    model.status = Status.DOWNLOAD
                else:
                    raise e
            update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    if update_key == 'credential':
                        model_credential_str = json.dumps(credential)
                        setattr(model, update_key, rsa_long_encrypt(model_credential_str))
                    else:
                        setattr(model, update_key, instance.get(update_key))
            model.save()
            # Evict the cache only after saving, so a concurrent reader cannot re-cache the
            # old row in the window between eviction and the write.
            ModelManage.delete_key(str(model.id))
            if model.status == Status.DOWNLOAD:
                thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
                thread.start()
            return self.one(with_valid=False)
|
|
|
|
|
|
|
2023-11-16 05:16:27 +00:00
|
|
|
|
|
|
|
|
|
|
class ProviderSerializer(serializers.Serializer):
    """Invoke a named public method on a model provider."""
    provider = serializers.CharField(required=True, error_messages=ErrMessage.char(_("provider")))

    method = serializers.CharField(required=True, error_messages=ErrMessage.char("执行函数名称"))

    def exec(self, exec_params: Dict[str, object], with_valid=False):
        """Call ``<provider>.<method>(exec_params)`` and return its result.

        ``method`` comes from request data, so names that are empty or start with ``_``
        (private helpers and dunders such as ``__init__`` / ``__class__``) are rejected
        with AppApiException instead of being dispatched.
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        provider = self.data.get('provider')
        method = self.data.get('method')
        if not method or method.startswith('_'):
            raise AppApiException(500, _('Unsupported method: {method}').format(method=method))
        return getattr(ModelProvideConstants[provider].value, method)(exec_params)
|