UnisKB/apps/setting/models_provider/base_model_provider.py

254 lines
8.2 KiB
Python
Raw Normal View History

# coding=utf-8
"""
@project: maxkb
@Author
@file base_model_provider.py
@date: 2023/10/31 16:19
@desc:
"""
from abc import ABC, abstractmethod
from enum import Enum
from functools import reduce
2024-07-12 06:15:42 +00:00
from typing import Dict, Iterator, Type, List
2024-07-12 06:15:42 +00:00
from pydantic.v1 import BaseModel
2024-03-22 09:56:56 +00:00
from common.exception.app_exception import AppApiException
class DownModelChunkStatus(Enum):
    """Status carried by a ``DownModelChunk`` (model-download progress record).

    ``DownModelChunk.to_dict`` serialises a member through its ``.value``.
    """
    success = "success"
    error = "error"
    pulling = "pulling"  # presumably: download still in progress — confirm with the producer
    unknown = 'unknown'
class ValidCode(Enum):
    """Numeric codes for model-validation outcomes.

    The values mirror HTTP-style codes: 500 for a validation error, 404 when
    the model is not found.
    """
    valid_error = 500
    # Misspelled canonical name; kept so existing callers keep working.
    model_not_fount = 404
    # Correctly spelled alias (same value => Enum alias of ``model_not_fount``);
    # does not change iteration order, ``ValidCode(404)`` or ``.name``.
    model_not_found = 404
class DownModelChunk:
    """A single progress record produced while downloading a model."""

    def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
        # Stored in the same order as the original constructor assigned them.
        self.details, self.status, self.digest, self.progress, self.index = (
            details, status, digest, progress, index)

    def to_dict(self):
        """Return a plain-dict view of the chunk; ``status`` is given by its enum value."""
        status_value = self.status.value
        return {
            "details": self.details,
            "status": status_value,
            "digest": self.digest,
            "progress": self.progress,
            "index": self.index,
        }
class IModelProvider(ABC):
    """Abstract model provider.

    Subclasses supply a model-info registry (``get_model_info_manage``) and a
    provider descriptor (``get_model_provide_info``); the concrete methods
    below look models up in that registry.
    """

    @abstractmethod
    def get_model_info_manage(self):
        """Return this provider's model-info registry (e.g. a ``ModelInfoManage``)."""

    @abstractmethod
    def get_model_provide_info(self):
        """Return this provider's descriptor (e.g. a ``ModelProvideInfo``)."""

    def get_model_type_list(self):
        """Return the model types the registry has models for."""
        manage = self.get_model_info_manage()
        return manage.get_model_type_list()

    def get_model_list(self, model_type):
        """Return the models of ``model_type``.

        :raises AppApiException: (500) when ``model_type`` is None.
        """
        if model_type is not None:
            return self.get_model_info_manage().get_model_list_by_model_type(model_type)
        raise AppApiException(500, '模型类型不能为空')

    def get_model_credential(self, model_type, model_name):
        """Return the credential handler registered for the given model."""
        return self.get_model_info_manage().get_model_info(model_type, model_name).model_credential

    def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
        """Validate ``model_credential`` with the model's credential handler.

        The provider itself is passed to the handler as ``provider``.
        """
        credential = self.get_model_info_manage().get_model_info(model_type, model_name).model_credential
        return credential.is_valid(model_type, model_name, model_credential, self,
                                   raise_exception=raise_exception)

    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
        """Create a model instance through the registered model class's ``new_instance``."""
        model_class = self.get_model_info_manage().get_model_info(model_type, model_name).model_class
        return model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)

    def get_dialogue_number(self):
        """Return the dialogue number for this provider (3 by default).

        NOTE(review): the meaning is not visible here — presumably how many
        previous dialogue rounds are kept; confirm with callers.
        """
        return 3

    def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
        """Download a model; unsupported by default.

        :raises AppApiException: (500) always, unless overridden.
        """
        raise AppApiException(500, "当前平台不支持下载模型")
class MaxKBBaseModel(ABC):
    """Abstract base for model classes registered in a ``ModelInfo``."""

    @staticmethod
    @abstractmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a model instance from its type, name, credential and extra kwargs.

        Subclasses must implement this; the base body returns None.
        """

    @staticmethod
    def is_cache_model():
        """Flag telling whether the model is cacheable; the base class answers True.

        How callers use the flag is not visible in this file.
        """
        return True
class BaseModelCredential(ABC):
    """Abstract credential handler for a model: validation, masking, parameter form."""

    @abstractmethod
    def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
        """Check the credential dict ``model`` for ``model_name`` of ``model_type``.

        :param provider: the provider performing the check
            (``IModelProvider.is_valid_credential`` passes itself).
        :param raise_exception: whether a failure should raise instead of returning.
        """
        pass

    @abstractmethod
    def encryption_dict(self, model_info: Dict[str, object]):
        """
        :param model_info: model data
        :return: data with sensitive fields encrypted (masked)
        """
        pass

    def get_model_params_setting_form(self, model_name):
        """
        Model parameter settings form; the base implementation returns None.
        :return:
        """
        pass

    @staticmethod
    def encryption(message: str):
        """Mask a sensitive field value so only its edges stay visible.

        Keeps the first ``len*2//5`` characters (at least 1, at most 8) and the
        last ``len//5`` characters (at most 4), joined by a fixed run of 15
        asterisks. Example: ``"1234567890"`` -> ``"1234***************90"``.

        :param message: the sensitive value.
        :return: the masked string; an empty/None ``message`` is returned unchanged.
        """
        if not message:
            # Previously this indexed message[0] and raised IndexError on "".
            return message
        max_pre_len = 8
        max_post_len = 4
        message_len = len(message)
        # Integer forms of int(len / 5 * 2) and int(len / 5 * 1).
        pre_len = min(max(message_len * 2 // 5, 1), max_pre_len)
        # Bug fix: the tail was clamped by testing ``pre_len`` instead of
        # ``post_len``, which revealed 4 trailing chars for 10–19-char secrets.
        post_len = min(message_len // 5, max_post_len)
        content = "***************"
        return message[:pre_len] + content + message[message_len - post_len:]
class ModelTypeConst(Enum):
    """Supported model categories.

    Each value is a dict whose ``code`` equals the member name and whose
    ``message`` is the display label (LLM = large language model,
    EMBEDDING = vector/embedding model, STT = speech recognition,
    TTS = speech synthesis, RERANKER = re-ranking model).
    """
    LLM = {'code': 'LLM', 'message': '大语言模型'}
    EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
    STT = {'code': 'STT', 'message': '语音识别'}
    TTS = {'code': 'TTS', 'message': '语音合成'}
    RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
class ModelInfo:
    """Metadata for one model offered by a provider.

    Extra keyword arguments become additional instance attributes and are
    included in :meth:`to_dict`.
    """

    def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
                 model_class: Type[MaxKBBaseModel],
                 **keywords):
        """
        :param name: model name.
        :param desc: model description.
        :param model_type: model category; only the member name is stored.
        :param model_credential: credential handler for this model.
        :param model_class: class whose ``new_instance`` builds the model.
        :param keywords: extra attributes set verbatim on the instance.
        """
        self.name = name
        self.desc = desc
        self.model_type = model_type.name
        self.model_credential = model_credential
        self.model_class = model_class
        # ``**keywords`` is always a dict (possibly empty), never None.
        for key, value in keywords.items():
            setattr(self, key, value)

    def get_name(self):
        """
        Get the model name.
        :return: model name
        """
        return self.name

    def get_desc(self):
        """
        Get the model description.
        :return: model description
        """
        return self.desc

    def get_model_type(self):
        """Return the model type as the ``ModelTypeConst`` member name (a str)."""
        return self.model_type

    def get_model_class(self):
        """Return the model class used to create instances."""
        return self.model_class

    def to_dict(self):
        """Return the instance attributes as a dict, in insertion order.

        ``model_credential`` and ``model_class`` are left out; extra keyword
        attributes are included.
        """
        excluded = ('model_credential', 'model_class')
        return {attr: getattr(self, attr) for attr in vars(self)
                if not attr.startswith("__") and attr not in excluded}
class ModelInfoManage:
    """Registry of a provider's models, indexed by model type and model name.

    Instances are usually assembled with the nested fluent ``builder``.
    """

    def __init__(self):
        # model type name -> {model name -> ModelInfo}
        self.model_dict = {}
        # every registered ModelInfo, in registration order
        self.model_list = []
        # per-type fallback ModelInfo used by get_model_info for unknown names
        self.default_model_list = []
        self.default_model_dict = {}

    def append_model_info(self, model_info: ModelInfo):
        """Register ``model_info`` under its type and name (same name => later one wins the lookup)."""
        self.model_list.append(model_info)
        self.model_dict.setdefault(model_info.model_type, {})[model_info.name] = model_info

    def append_default_model_info(self, model_info: ModelInfo):
        """Set ``model_info`` as the fallback for its model type (not added to ``model_list``)."""
        self.default_model_list.append(model_info)
        self.default_model_dict[model_info.model_type] = model_info

    def get_model_list(self):
        """Return ``to_dict()`` of every registered model."""
        return [model.to_dict() for model in self.model_list]

    def get_model_list_by_model_type(self, model_type):
        """Return ``to_dict()`` of the registered models whose type equals ``model_type``."""
        return [model.to_dict() for model in self.model_list if model.model_type == model_type]

    def get_model_type_list(self):
        """Return ``{'key': label, 'value': code}`` for each ``ModelTypeConst`` with at least one model.

        Entries follow ``ModelTypeConst`` declaration order.
        """
        # Collect present types once instead of filtering model_list per type.
        present_types = {model.model_type for model in self.model_list}
        return [{'key': _type.value.get('message'), 'value': _type.value.get('code')}
                for _type in ModelTypeConst if _type.name in present_types]

    def get_model_info(self, model_type, model_name) -> ModelInfo:
        """Return the ModelInfo for ``model_name``, falling back to the type's default.

        :raises AppApiException: (500) when neither a named nor a default model exists.
        """
        model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type))
        if model_info is None:
            raise AppApiException(500, '模型不支持')
        return model_info

    class builder:
        """Fluent builder: chain ``append_*`` calls, then call :meth:`build`."""

        def __init__(self):
            self.modelInfoManage = ModelInfoManage()

        def append_model_info(self, model_info: ModelInfo):
            """Register one model; returns the builder for chaining."""
            self.modelInfoManage.append_model_info(model_info)
            return self

        def append_model_info_list(self, model_info_list: List[ModelInfo]):
            """Register several models in order; returns the builder for chaining."""
            for model_info in model_info_list:
                self.modelInfoManage.append_model_info(model_info)
            return self

        def append_default_model_info(self, model_info: ModelInfo):
            """Register a per-type fallback model; returns the builder for chaining."""
            self.modelInfoManage.append_default_model_info(model_info)
            return self

        def build(self):
            """Return the assembled ModelInfoManage."""
            return self.modelInfoManage
class ModelProvideInfo:
    """Descriptor of a model provider: identifier, display name and icon."""

    def __init__(self, provider: str, name: str, icon: str):
        self.provider = provider
        self.name = name
        self.icon = icon

    def to_dict(self):
        """Return the instance attributes (excluding ``__``-prefixed names) as a dict.

        A single comprehension replaces the previous ``reduce`` merge, which
        copied the growing dict once per attribute.
        """
        return {attr: getattr(self, attr) for attr in vars(self) if not attr.startswith("__")}