UnisKB/apps/models_provider/tools.py

124 lines
3.8 KiB
Python
Raw Normal View History

2025-04-17 10:01:33 +00:00
# coding=utf-8
"""
@project: MaxKB
@Author
@file tools.py
@date2024/7/22 11:18
@desc:
"""
from django.db import connection
from django.db.models import QuerySet
from common.config.embedding_config import ModelManage
from models_provider.models import Model
from django.utils.translation import gettext_lazy as _
import json
from typing import Dict
from common.utils.rsa_util import rsa_long_decrypt
from models_provider.constants.model_provider_constants import ModelProvideConstants
def get_model_(provider, model_type, model_name, credential, model_id, use_local=False, **kwargs):
"""
获取模型实例
@param provider: 供应商
@param model_type: 模型类型
@param model_name: 模型名称
@param credential: 认证信息
@param model_id: 模型id
@param use_local: 是否调用本地模型 只适用于本地供应商
@return: 模型实例
"""
model = get_provider(provider).get_model(model_type, model_name,
json.loads(
rsa_long_decrypt(credential)),
model_id=model_id,
use_local=use_local,
streaming=True, **kwargs)
return model
def get_model(model, **kwargs):
"""
获取模型实例
@param model: model 数据库Model实例对象
@return: 模型实例
"""
return get_model_(model.provider, model.model_type, model.model_name, model.credential, str(model.id), **kwargs)
def get_provider(provider):
"""
获取供应商实例
@param provider: 供应商字符串
@return: 供应商实例
"""
return ModelProvideConstants[provider].value
def get_model_list(provider, model_type):
"""
获取模型列表
@param provider: 供应商字符串
@param model_type: 模型类型
@return: 模型列表
"""
return get_provider(provider).get_model_list(model_type)
def get_model_credential(provider, model_type, model_name):
"""
获取模型认证实例
@param provider: 供应商字符串
@param model_type: 模型类型
@param model_name: 模型名称
@return: 认证实例对象
"""
return get_provider(provider).get_model_credential(model_type, model_name)
def get_model_type_list(provider):
"""
获取模型类型列表
@param provider: 供应商字符串
@return: 模型类型列表
"""
return get_provider(provider).get_model_type_list()
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params,
raise_exception=False):
"""
校验模型认证参数
@param provider: 供应商字符串
@param model_type: 模型类型
@param model_name: 模型名称
@param model_credential: 模型认证数据
@param raise_exception: 是否抛出错误
@return: True|False
"""
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params,
raise_exception)
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
# 手动关闭数据库连接
connection.close()
if model is None:
raise Exception(_('Model does not exist'))
return model
def get_model_instance_by_model_user_id(model_id, user_id, **kwargs):
"""
获取模型实例,根据模型相关数据
@param model_id: 模型id
@param user_id: 用户id
@return: 模型实例
"""
model = get_model_by_id(model_id, user_id)
return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs))