UnisKB/apps/setting/serializers/provider_serializers.py

272 lines
13 KiB
Python
Raw Normal View History

# coding=utf-8
"""
@project: maxkb
@Author
@file provider_serializers.py
@date2023/11/2 14:01
@desc:
"""
import json
2024-03-22 09:56:56 +00:00
import threading
import time
import uuid
from typing import Dict
from django.db.models import QuerySet
from rest_framework import serializers
from application.models import Application
from common.exception.app_exception import AppApiException
2024-03-04 02:12:18 +00:00
from common.util.field_message import ErrMessage
from common.util.rsa_util import encrypt, decrypt
2024-03-22 09:56:56 +00:00
from setting.models.model_management import Model, Status
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
2024-03-22 09:56:56 +00:00
class ModelPullManage:
@staticmethod
def pull(model: Model, credential: Dict):
response = ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name,
credential)
down_model_chunk = {}
timestamp = time.time()
for chunk in response:
down_model_chunk[chunk.digest] = chunk.to_dict()
if time.time() - timestamp > 5:
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": list(down_model_chunk.values())})
timestamp = time.time()
status = Status.ERROR
message = ""
down_model_chunk_list = list(down_model_chunk.values())
for chunk in down_model_chunk_list:
if chunk.get('status') == DownModelChunkStatus.success.value:
status = Status.SUCCESS
if chunk.get('status') == DownModelChunkStatus.error.value:
message = chunk.get("digest")
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": [], "message": message},
2024-03-22 09:56:56 +00:00
status=status)
class ModelSerializer(serializers.Serializer):
class Query(serializers.Serializer):
2024-03-04 02:12:18 +00:00
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
2024-03-04 02:25:51 +00:00
name = serializers.CharField(required=False, max_length=20,
2024-03-04 02:12:18 +00:00
error_messages=ErrMessage.char("模型名称"))
2024-03-04 02:12:18 +00:00
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
2024-03-04 02:12:18 +00:00
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("基础模型"))
2024-03-04 02:12:18 +00:00
provider = serializers.CharField(required=False, error_messages=ErrMessage.char("供应商"))
def list(self, with_valid):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
name = self.data.get('name')
model_query_set = QuerySet(Model).filter(user_id=user_id)
query_params = {}
if name is not None:
query_params['name__contains'] = name
if self.data.get('model_type') is not None:
query_params['model_type'] = self.data.get('model_type')
if self.data.get('model_name') is not None:
query_params['model_name'] = self.data.get('model_name')
if self.data.get('provider') is not None:
query_params['provider'] = self.data.get('provider')
2024-03-22 09:56:56 +00:00
return [
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
model_query_set.filter(**query_params).order_by("-create_time")]
2023-12-01 09:30:06 +00:00
class Edit(serializers.Serializer):
2024-03-04 02:12:18 +00:00
user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id"))
2023-12-01 09:30:06 +00:00
2024-03-04 02:25:51 +00:00
name = serializers.CharField(required=False, max_length=20,
2024-03-04 02:12:18 +00:00
error_messages=ErrMessage.char("模型名称"))
2023-12-01 09:30:06 +00:00
2024-03-04 02:12:18 +00:00
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
2023-12-01 09:30:06 +00:00
2024-03-04 02:12:18 +00:00
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
2023-12-01 09:30:06 +00:00
2024-03-04 02:12:18 +00:00
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
2023-12-01 09:30:06 +00:00
def is_valid(self, model=None, raise_exception=False):
super().is_valid(raise_exception=True)
filter_params = {'user_id': self.data.get('user_id')}
if 'name' in self.data and self.data.get('name') is not None:
filter_params['name'] = self.data.get('name')
if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists():
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
ModelSerializer.model_to_dict(model)
provider = model.provider
model_type = self.data.get('model_type')
model_name = self.data.get(
'model_name')
credential = self.data.get('credential')
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
model_name)
source_model_credential = json.loads(decrypt(model.credential))
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
if credential is not None:
for k in source_encryption_model_credential.keys():
if credential[k] == source_encryption_model_credential[k]:
credential[k] = source_model_credential[k]
2024-03-22 09:56:56 +00:00
return credential, model_credential
2023-12-01 09:30:06 +00:00
class Create(serializers.Serializer):
2024-03-04 02:12:18 +00:00
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
2024-03-04 02:25:51 +00:00
name = serializers.CharField(required=True, max_length=20, error_messages=ErrMessage.char("模型名称"))
2024-03-04 02:12:18 +00:00
provider = serializers.CharField(required=True, error_messages=ErrMessage.char("供应商"))
2024-03-04 02:12:18 +00:00
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
2024-03-04 02:12:18 +00:00
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
2024-03-04 02:12:18 +00:00
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if QuerySet(Model).filter(user_id=self.data.get('user_id'),
name=self.data.get('name')).exists():
raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在')
# 校验模型认证数据
ModelProvideConstants[self.data.get('provider')].value.get_model_credential(self.data.get('model_type'),
self.data.get(
'model_name')).is_valid(
self.data.get('model_type'),
self.data.get('model_name'),
self.data.get('credential'),
raise_exception=True)
def insert(self, user_id, with_valid=False):
2024-03-22 09:56:56 +00:00
status = Status.SUCCESS
if with_valid:
2024-03-22 09:56:56 +00:00
try:
self.is_valid(raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
status = Status.DOWNLOAD
else:
raise e
credential = self.data.get('credential')
name = self.data.get('name')
provider = self.data.get('provider')
model_type = self.data.get('model_type')
model_name = self.data.get('model_name')
model_credential_str = json.dumps(credential)
2024-03-22 09:56:56 +00:00
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name)
model.save()
2024-03-22 09:56:56 +00:00
if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
2023-12-01 09:30:06 +00:00
return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)
@staticmethod
def model_to_dict(model: Model):
credential = json.loads(decrypt(model.credential))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
2024-03-22 09:56:56 +00:00
'status': model.status,
'meta': model.meta,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).encryption_dict(
credential)}
class Operate(serializers.Serializer):
2024-03-04 02:12:18 +00:00
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
2024-03-04 02:12:18 +00:00
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
2023-12-01 09:30:06 +00:00
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
model = QuerySet(Model).filter(id=self.data.get("id"), user_id=self.data.get("user_id")).first()
if model is None:
raise AppApiException(500, '模型不存在')
def one(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
2023-12-01 09:30:06 +00:00
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
return ModelSerializer.model_to_dict(model)
2024-03-22 09:56:56 +00:00
def one_meta(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta, }
2023-12-01 09:30:06 +00:00
def delete(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
application_list = QuerySet(Application).filter(model_id=self.data.get('id')).all()
if len(application_list) > 0:
raise AppApiException(500, f"该模型关联了{len(application_list)} 个应用,无法删除该模型。")
2023-12-01 09:30:06 +00:00
QuerySet(Model).filter(id=self.data.get('id')).delete()
return True
def edit(self, instance: Dict, user_id: str, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
model = QuerySet(Model).filter(id=self.data.get('id')).first()
if model is None:
raise AppApiException(500, '不存在的id')
else:
2024-03-22 09:56:56 +00:00
credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(
model=model)
try:
model.status = Status.SUCCESS
2024-03-22 09:56:56 +00:00
# 校验模型认证数据
model_credential.is_valid(
model.model_type,
instance.get("model_name"),
credential,
raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
model.status = Status.DOWNLOAD
else:
raise e
2023-12-01 09:30:06 +00:00
update_keys = ['credential', 'name', 'model_type', 'model_name']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':
model_credential_str = json.dumps(credential)
model.__setattr__(update_key, encrypt(model_credential_str))
else:
model.__setattr__(update_key, instance.get(update_key))
model.save()
2024-03-22 09:56:56 +00:00
if model.status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
2023-12-01 09:30:06 +00:00
return self.one(with_valid=False)
class ProviderSerializer(serializers.Serializer):
2024-03-04 02:12:18 +00:00
provider = serializers.CharField(required=True, error_messages=ErrMessage.char("供应商"))
2024-03-04 02:12:18 +00:00
method = serializers.CharField(required=True, error_messages=ErrMessage.char("执行函数名称"))
def exec(self, exec_params: Dict[str, object], with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
provider = self.data.get('provider')
method = self.data.get('method')
return getattr(ModelProvideConstants[provider].value, method)(exec_params)