perf: Memory optimization (#4318)
parent
1f4d6d1123
commit
a8d0729e65
|
|
@ -224,11 +224,12 @@ class LoopWorkFlowPostHandler(WorkFlowPostHandler):
|
||||||
|
|
||||||
class BaseLoopNode(ILoopNode):
|
class BaseLoopNode(ILoopNode):
|
||||||
def save_context(self, details, workflow_manage):
|
def save_context(self, details, workflow_manage):
|
||||||
self.context['result'] = details.get('result')
|
self.context['loop_context_data'] = details.get('loop_context_data')
|
||||||
|
self.context['loop_answer_data'] = details.get('loop_answer_data')
|
||||||
for key, value in details['context'].items():
|
for key, value in details['context'].items():
|
||||||
if key not in self.context:
|
if key not in self.context:
|
||||||
self.context[key] = value
|
self.context[key] = value
|
||||||
self.answer_text = str(details.get('result'))
|
self.answer_text = ""
|
||||||
|
|
||||||
def get_answer_list(self) -> List[Answer] | None:
|
def get_answer_list(self) -> List[Answer] | None:
|
||||||
result = []
|
result = []
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from urllib.parse import urljoin
|
||||||
|
|
||||||
import uuid_utils.compat as uuid
|
import uuid_utils.compat as uuid
|
||||||
from charset_normalizer import detect
|
from charset_normalizer import detect
|
||||||
from django.db.models import QuerySet
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
from common.handle.base_split_handle import BaseSplitHandle
|
from common.handle.base_split_handle import BaseSplitHandle
|
||||||
|
|
@ -39,7 +38,6 @@ class FileBufferHandle:
|
||||||
return self.buffer
|
return self.buffer
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
default_split_handle = TextSplitHandle()
|
default_split_handle = TextSplitHandle()
|
||||||
split_handles = [
|
split_handles = [
|
||||||
HTMLSplitHandle(),
|
HTMLSplitHandle(),
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from django.contrib import admin
|
||||||
|
|
||||||
|
# Register your models here.
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class LocalModelConfig(AppConfig):
    """Django application configuration for the ``local_model`` app."""

    # Dotted path of the application package this config belongs to.
    name = 'local_model'
    # Implicit primary-key field type for models that do not declare one.
    default_auto_field = 'django.db.models.BigAutoField'
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2023/9/25 15:04
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .model_management import *
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
# coding=utf-8
|
||||||
|
import uuid_utils.compat as uuid
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
from common.mixins.app_model_mixin import AppModelMixin
|
||||||
|
from local_model.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
class Status(models.TextChoices):
    """Status values for a model record (used as the choices of ``Model.status``)."""

    # Each member is ``(stored value, human-readable label)``.
    SUCCESS = "SUCCESS", '成功'  # label: "success"
    ERROR = "ERROR", "失败"  # label: "failed"
    DOWNLOAD = "DOWNLOAD", '下载中'  # label: "downloading"
    PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载'  # label: "download paused"
|
||||||
|
|
||||||
|
|
||||||
|
class Model(AppModelMixin):
    """
    Model data: one configured model (provider, type, credential and
    parameters), stored in the ``model`` table.
    """

    # Primary key, generated with uuid7.
    id = models.UUIDField(
        primary_key=True,
        max_length=128,
        default=uuid.uuid7,
        editable=False,
        verbose_name="主键id",
    )
    # Display name; unique per workspace (see Meta.unique_together).
    name = models.CharField(
        max_length=128,
        verbose_name="名称",
        db_index=True,
    )
    # One of the Status choices; defaults to SUCCESS.
    status = models.CharField(
        max_length=20,
        verbose_name='设置类型',
        choices=Status.choices,
        default=Status.SUCCESS,
        db_index=True,
    )
    model_type = models.CharField(
        max_length=128,
        verbose_name="模型类型",
        db_index=True,
    )
    model_name = models.CharField(
        max_length=128,
        verbose_name="模型名称",
        db_index=True,
    )
    # Optional owner; no database-level FK constraint is created.
    user = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        db_constraint=False,
        blank=True,
        null=True,
    )
    provider = models.CharField(
        max_length=128,
        verbose_name='供应商',
        db_index=True,
    )
    # Model credential information.
    credential = models.CharField(
        max_length=102400,
        verbose_name="模型认证信息",
    )
    # Model metadata, used to store download or error information.
    meta = models.JSONField(
        verbose_name="模型元数据,用于存储下载,或者错误信息",
        default=dict,
    )
    # Model parameter configuration.
    model_params_form = models.JSONField(
        verbose_name="模型参数配置",
        default=list,
    )
    workspace_id = models.CharField(
        max_length=64,
        verbose_name="工作空间id",
        default="default",
        db_index=True,
    )

    class Meta:
        db_table = "model"
        unique_together = ['name', 'workspace_id']
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: system_management.py
|
||||||
|
@date:2024/3/19 13:47
|
||||||
|
@desc: 邮箱管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
from common.mixins.app_model_mixin import AppModelMixin
|
||||||
|
|
||||||
|
|
||||||
|
class SettingType(models.IntegerChoices):
    """System setting type; the integer is the primary key of ``SystemSetting``."""

    # Each member is ``(stored integer, human-readable label)``.
    EMAIL = 0, '邮箱'  # label: "email"
    RSA = 1, "私钥秘钥"  # label: "private key / secret key"
    LOG = 2, "日志清理时间"  # label: "log clean-up time"
|
||||||
|
|
||||||
|
|
||||||
|
class SystemSetting(AppModelMixin):
    """
    System setting: one row per ``SettingType``, with the configuration
    payload kept in ``meta``. Stored in the ``system_setting`` table.
    """

    # The setting type doubles as the primary key.
    type = models.IntegerField(
        primary_key=True,
        verbose_name='设置类型',
        choices=SettingType.choices,
        default=SettingType.EMAIL,
    )
    # Configuration data for this setting type.
    meta = models.JSONField(
        verbose_name="配置数据",
        default=dict,
    )

    class Meta:
        db_table = "system_setting"
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: user.py
|
||||||
|
@date:2025/4/14 10:20
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import uuid_utils.compat as uuid
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
from common.utils.common import password_encrypt
|
||||||
|
|
||||||
|
|
||||||
|
class User(models.Model):
    """User account record stored in the ``user`` table."""

    # Primary key, generated with uuid7.
    id = models.UUIDField(
        primary_key=True,
        max_length=128,
        default=uuid.uuid7,
        editable=False,
        verbose_name="主键id",
    )
    email = models.EmailField(
        unique=True,
        null=True,
        blank=True,
        verbose_name="邮箱",
        db_index=True,
    )
    phone = models.CharField(
        max_length=20,
        verbose_name="电话",
        default="",
        db_index=True,
    )
    nick_name = models.CharField(
        max_length=150,
        verbose_name="昵称",
        unique=True,
        db_index=True,
    )
    username = models.CharField(
        max_length=150,
        unique=True,
        verbose_name="用户名",
        db_index=True,
    )
    # Holds the output of password_encrypt (see set_password).
    password = models.CharField(
        max_length=150,
        verbose_name="密码",
    )
    role = models.CharField(
        max_length=150,
        verbose_name="角色",
    )
    source = models.CharField(
        max_length=10,
        verbose_name="来源",
        default="LOCAL",
        db_index=True,
    )
    is_active = models.BooleanField(
        default=True,
        db_index=True,
    )
    language = models.CharField(
        max_length=10,
        verbose_name="语言",
        null=True,
        default=None,
    )
    create_time = models.DateTimeField(
        verbose_name="创建时间",
        auto_now_add=True,
        null=True,
        db_index=True,
    )
    update_time = models.DateTimeField(
        verbose_name="修改时间",
        auto_now=True,
        null=True,
        db_index=True,
    )

    USERNAME_FIELD = 'username'
    REQUIRED_FIELDS = []

    class Meta:
        db_table = "user"

    def set_password(self, row_password):
        """
        Store the encrypted form of ``row_password`` on the instance.

        :param row_password: the plain-text password
        """
        # The plain value is also kept on the private ``_password`` attribute.
        self._password = row_password
        self.password = password_encrypt(row_password)
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
# coding=utf-8
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: model_apply_serializers.py
|
||||||
|
@date:2024/8/20 20:39
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from django.db import connection
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from local_model.models import Model
|
||||||
|
from local_model.serializers.rsa_util import rsa_long_decrypt
|
||||||
|
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
|
||||||
|
|
||||||
|
from common.cache.mem_cache import MemCache
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
locks = {}
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManage:
    """Process-wide cache of loaded model instances, keyed by model id."""

    cache = MemCache('model', {})
    # Time at which the last expired-entry sweep was started.
    up_clear_time = time.time()

    @staticmethod
    def _get_lock(_id):
        """Return the lock dedicated to ``_id``, creating it on first use."""
        per_id_lock = locks.get(_id)
        if per_id_lock is not None:
            return per_id_lock
        with _lock:
            # Look again under the global lock: another thread may have
            # registered the lock between the first lookup and here.
            per_id_lock = locks.get(_id)
            if per_id_lock is None:
                per_id_lock = threading.Lock()
                locks[_id] = per_id_lock
        return per_id_lock

    @staticmethod
    def get_model(_id, get_model):
        """
        Return the cached model instance for ``_id``, loading it if needed.

        :param _id: cache key (model id)
        :param get_model: callable taking ``_id`` and returning a model instance
        :return: the model instance
        """
        instance = ModelManage.cache.get(_id)
        if instance is not None:
            if instance.is_cache_model():
                # Still usable: just extend its lifetime.
                ModelManage.cache.touch(_id, timeout=60 * 60 * 8)
            else:
                # Cached instance reports it should not be reused: reload it.
                instance = get_model(_id)
                ModelManage.cache.set(_id, instance, timeout=60 * 60 * 8)
        else:
            with ModelManage._get_lock(_id):
                # Re-read under the per-id lock so only one thread loads.
                instance = ModelManage.cache.get(_id)
                if instance is None:
                    instance = get_model(_id)
                    ModelManage.cache.set(_id, instance, timeout=60 * 60 * 8)
        ModelManage.clear_timeout_cache()
        return instance

    @staticmethod
    def clear_timeout_cache():
        """Start a background sweep of expired entries, at most once per hour."""
        if not time.time() - ModelManage.up_clear_time > 60 * 60:
            return
        sweeper = threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data())
        sweeper.start()
        ModelManage.up_clear_time = time.time()

    @staticmethod
    def delete_key(_id):
        """Remove ``_id`` from the cache if it is present."""
        if not ModelManage.cache.has_key(_id):
            return
        ModelManage.cache.delete(_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_model(model, **kwargs):
    """
    Build a model instance from a stored ``Model`` record.

    :param model: record providing ``model_type``, ``model_name``,
                  ``credential`` and ``id``
    :param kwargs: extra keyword arguments forwarded to the provider
    :return: whatever ``LocalModelProvider.get_model`` returns
    """
    provider = LocalModelProvider()
    # The credential is stored RSA-encrypted; after decryption it is JSON.
    credential = json.loads(rsa_long_decrypt(model.credential))
    return provider.get_model(
        model.model_type,
        model.model_name,
        credential,
        model_id=model.id,
        streaming=True,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_model(model_id):
    """
    Return the (cached) model instance for ``model_id``.

    :param model_id: primary key of the ``Model`` record
    :return: the instance obtained through ``ModelManage.get_model``
    """
    model = QuerySet(Model).filter(id=model_id).first()
    # Close the database connection manually.
    connection.close()

    def _load(_id):
        # Only invoked by ModelManage when a new instance has to be built.
        return get_local_model(model, use_local=True)

    return ModelManage.get_model(model_id, _load)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedDocuments(serializers.Serializer):
    """Validates the payload of an embed-documents request."""

    # NOTE: this must be a bare field instance. A trailing comma after the
    # call would turn the attribute into a 1-tuple, which the serializer does
    # not register as a field, so ``texts`` would never be validated.
    texts = serializers.ListField(
        required=True,
        child=serializers.CharField(required=True, label=_('vector text')),
        label=_('vector text list'),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedQuery(serializers.Serializer):
    """Validates the payload of an embed-query request."""

    # The single text to embed.
    text = serializers.CharField(
        required=True,
        label=_('vector text'),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CompressDocument(serializers.Serializer):
    """Validates one document of a compress-documents request."""

    page_content = serializers.CharField(
        required=True,
        label=_('text'),
    )
    # Optional per-document metadata.
    metadata = serializers.DictField(
        required=False,
        label=_('metadata'),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CompressDocuments(serializers.Serializer):
    """Validates the payload of a compress-documents request."""

    # The documents to process, each validated by CompressDocument.
    documents = CompressDocument(
        required=True,
        many=True,
    )
    query = serializers.CharField(
        required=True,
        label=_('query'),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ModelApplySerializers(serializers.Serializer):
    """Runs embedding / document-compression calls against a stored model."""

    model_id = serializers.UUIDField(required=True, label=_('model id'))

    def embed_documents(self, instance, with_valid=True):
        """
        Embed a list of texts.

        :param instance: request payload containing ``texts``
        :param with_valid: validate ``model_id`` and the payload first
        :return: result of the model's ``embed_documents``
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedDocuments(data=instance).is_valid(raise_exception=True)

        model = get_embedding_model(self.data.get('model_id'))
        # Multi-value mappings (form data) expose ``getlist``; a plain dict
        # (e.g. a JSON body) does not, so fall back to ``get`` for it.
        if hasattr(instance, 'getlist'):
            texts = instance.getlist('texts')
        else:
            texts = instance.get('texts')
        return model.embed_documents(texts)

    def embed_query(self, instance, with_valid=True):
        """
        Embed a single query text.

        :param instance: request payload containing ``text``
        :param with_valid: validate ``model_id`` and the payload first
        :return: result of the model's ``embed_query``
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedQuery(data=instance).is_valid(raise_exception=True)

        model = get_embedding_model(self.data.get('model_id'))
        return model.embed_query(instance.get('text'))

    def compress_documents(self, instance, with_valid=True):
        """
        Run the model's ``compress_documents`` on the given documents and query.

        :param instance: request payload containing ``documents`` and ``query``
        :param with_valid: validate ``model_id`` and the payload first
        :return: list of ``{'page_content': ..., 'metadata': ...}`` dicts
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            CompressDocuments(data=instance).is_valid(raise_exception=True)

        model = get_embedding_model(self.data.get('model_id'))
        # ``metadata`` is optional in the payload; use an empty dict instead
        # of passing ``None`` to ``Document``.
        documents = [
            Document(page_content=document.get('page_content'),
                     metadata=document.get('metadata') or {})
            for document in instance.get('documents')
        ]
        compressed = model.compress_documents(documents, instance.get('query'))
        return [{'page_content': d.page_content, 'metadata': d.metadata} for d in compressed]
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: rsa_util.py
|
||||||
|
@date:2023/11/3 11:13
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
|
||||||
|
from Crypto.PublicKey import RSA
|
||||||
|
from django.core import cache
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from common.constants.cache_version import Cache_Version
|
||||||
|
from local_model.models.system_setting import SystemSetting, SettingType
|
||||||
|
|
||||||
|
lock = threading.Lock()
|
||||||
|
rsa_cache = cache.cache
|
||||||
|
cache_key = "rsa_key"
|
||||||
|
# 对密钥加密的密码
|
||||||
|
secret_code = "mac_kb_password"
|
||||||
|
|
||||||
|
|
||||||
|
def generate():
    """
    Generate an RSA key pair.

    :return: ``{'key': <public key>, 'value': <private key>}``, both as the
             bytes returned by ``export_key``
    """
    # Generate a 2048-bit key.
    rsa_key = RSA.generate(2048)
    # Export the private key, protected with the module-level passphrase.
    private_part = rsa_key.export_key(
        passphrase=secret_code,
        pkcs=8,
        protection="scryptAndAES128-CBC",
    )
    public_part = rsa_key.publickey().export_key()
    return {'key': public_part, 'value': private_part}
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_pair():
    """
    Return the RSA key pair, from the cache when possible.

    On a cache miss the pair is loaded through ``get_key_pair_by_sql`` and
    cached without expiry.

    :return: the ``meta`` mapping of the RSA system setting
             (``{'key': <public key>, 'value': <private key>}``)
    """
    # Read and write must use the SAME key and version; looking the value up
    # under the bare key / default version would not see what ``set`` stores.
    version, get_key = Cache_Version.SYSTEM.value
    versioned_key = get_key(key=cache_key)

    rsa_value = rsa_cache.get(versioned_key, version=version)
    if rsa_value is None:
        with lock:
            # Re-check under the lock so only one thread loads the pair.
            rsa_value = rsa_cache.get(versioned_key, version=version)
            if rsa_value is None:
                rsa_value = get_key_pair_by_sql()
                rsa_cache.set(versioned_key, rsa_value, timeout=None, version=version)
    return rsa_value
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_pair_by_sql():
    """
    Load the RSA key pair from the ``SystemSetting`` table, creating and
    saving a new pair when no RSA setting exists yet.

    :return: the ``meta`` mapping of the RSA setting
    """
    existing = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first()
    if existing is not None:
        return existing.meta

    kv = generate()
    # The exported keys are bytes; they are stored as text.
    created = SystemSetting(
        type=SettingType.RSA.value,
        meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()},
    )
    created.save()
    return created.meta
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt(msg, public_key: str | None = None):
    """
    Encrypt a string.

    :param msg: the text to encrypt
    :param public_key: public key; taken from ``get_key_pair`` when omitted
    :return: the ciphertext, base64-encoded, as a string
    """
    key_material = get_key_pair().get('key') if public_key is None else public_key
    cipher = PKCS1_cipher.new(RSA.importKey(key_material))
    ciphertext = cipher.encrypt(msg.encode("utf-8"))
    return base64.b64encode(ciphertext).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt(msg, pri_key: str | None = None):
    """
    Decrypt a base64-encoded ciphertext.

    :param msg: the data to decrypt (base64 text)
    :param pri_key: private key; taken from ``get_key_pair`` when omitted
    :return: the decrypted text
    """
    key_material = get_key_pair().get('value') if pri_key is None else pri_key
    cipher = PKCS1_cipher.new(RSA.importKey(key_material, passphrase=secret_code))
    # NOTE(review): the second argument (0) is the sentinel handed to
    # ``decrypt``; presumably it is returned on failure, in which case the
    # ``.decode`` below would fail on an int - confirm against the library.
    plaintext = cipher.decrypt(base64.b64decode(msg), 0)
    return plaintext.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
    """
    Encrypt text of arbitrary length by encrypting it in chunks.

    The text is UTF-8 encoded first and then split into chunks of ``length``
    BYTES. Splitting the string by character count instead would let
    multi-byte characters (e.g. Chinese text) push a chunk over the size the
    cipher accepts. The output is still what ``rsa_long_decrypt`` expects,
    because that function joins the decrypted bytes before decoding.

    :param message: the string to encrypt
    :param public_key: public key; taken from ``get_key_pair`` when omitted
    :param length: chunk size in bytes; use 100 for a 1024-bit key and 200
                   for a 2048-bit key
    :return: the concatenated ciphertext, base64-encoded, as a string
    """
    # Read the public key.
    if public_key is None:
        public_key = get_key_pair().get('key')
    cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
                                            passphrase=secret_code))

    payload = message.encode('utf-8')
    # An empty payload is still encrypted as one (empty) chunk, matching the
    # behavior for short messages.
    chunks = [payload[start:start + length]
              for start in range(0, len(payload), length)] or [payload]
    # Encrypt every chunk and concatenate the results before base64-encoding.
    cipher_text = b''.join(cipher.encrypt(chunk) for chunk in chunks)
    return base64.b64encode(cipher_text).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
    """
    Decrypt long text that was encrypted in chunks.

    :param message: the data to decrypt (base64 text)
    :param pri_key: private key; taken from ``get_key_pair`` when omitted
    :param length: ciphertext chunk size in bytes; use 128 for a 1024-bit key
                   and 256 for a 2048-bit key
    :return: the decrypted text
    """
    key_material = get_key_pair().get('value') if pri_key is None else pri_key
    cipher = PKCS1_cipher.new(RSA.importKey(key_material, passphrase=secret_code))
    raw = base64.b64decode(message)
    # Decrypt chunk by chunk, then join the plaintext bytes before decoding.
    pieces = [cipher.decrypt(raw[offset:offset + length], 0)
              for offset in range(0, len(raw), length)]
    return b"".join(pieces).decode()
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
# Create your tests here.
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from django.urls import path
|
||||||
|
|
||||||
|
from . import views
|
||||||
|
|
||||||
|
app_name = "local_model"
|
||||||
|
# @formatter:off
|
||||||
|
urlpatterns = [
|
||||||
|
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
|
||||||
|
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
|
||||||
|
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
# coding=utf-8
|
||||||
|
from .model_apply import *
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: model_apply.py
|
||||||
|
@date:2024/8/20 20:38
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from urllib.request import Request
|
||||||
|
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
from common.result import result
|
||||||
|
from local_model.serializers.model_apply_serializers import ModelApplySerializers
|
||||||
|
|
||||||
|
|
||||||
|
class LocalModelApply(APIView):
    """Groups the local-model endpoints as nested views."""

    class EmbedDocuments(APIView):
        """POST: embed the texts in the request body with model ``model_id``."""

        def post(self, request: Request, model_id):
            serializer = ModelApplySerializers(data={'model_id': model_id})
            embedded = serializer.embed_documents(request.data)
            return result.success(embedded)

    class EmbedQuery(APIView):
        """POST: embed the query text in the request body with model ``model_id``."""

        def post(self, request: Request, model_id):
            serializer = ModelApplySerializers(data={'model_id': model_id})
            embedded = serializer.embed_query(request.data)
            return result.success(embedded)

    class CompressDocuments(APIView):
        """POST: run ``compress_documents`` on the request body with model ``model_id``."""

        def post(self, request: Request, model_id):
            serializer = ModelApplySerializers(data={'model_id': model_id})
            compressed = serializer.compress_documents(request.data)
            return result.success(compressed)
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2025/11/5 14:50
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||||
|
from .model import *
|
||||||
|
else:
|
||||||
|
from .web import *
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: auth.py
|
||||||
|
@date:2024/7/9 18:47
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
|
||||||
|
AUTH_HANDLES = [
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2025/11/5 14:53
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||||
|
from .model import *
|
||||||
|
else:
|
||||||
|
from .web import *
|
||||||
|
|
@ -0,0 +1,179 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: model.py
|
||||||
|
@date:2025/11/5 14:53
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from ...const import CONFIG, PROJECT_DIR
|
||||||
|
import os
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
|
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
||||||
|
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||||
|
|
||||||
|
# Quick-start development settings - unsuitable for production
|
||||||
|
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
|
||||||
|
|
||||||
|
# SECURITY WARNING: keep the secret key used in production secret!
|
||||||
|
SECRET_KEY = CONFIG.get("SECRET_KEY") or 'django-insecure-zm^1_^i5)3gp^&0io6zg72&z!a*d=9kf9o2%uft+27l)+t(#3e'
|
||||||
|
|
||||||
|
# SECURITY WARNING: don't run with debug turned on in production!
|
||||||
|
DEBUG = CONFIG.get_debug()
|
||||||
|
|
||||||
|
ALLOWED_HOSTS = ['*']
|
||||||
|
|
||||||
|
# Application definition
|
||||||
|
|
||||||
|
INSTALLED_APPS = [
|
||||||
|
'django.contrib.contenttypes',
|
||||||
|
'django.contrib.messages',
|
||||||
|
'django.contrib.staticfiles',
|
||||||
|
'rest_framework',
|
||||||
|
'local_model',
|
||||||
|
]
|
||||||
|
|
||||||
|
MIDDLEWARE = [
|
||||||
|
'django.middleware.locale.LocaleMiddleware',
|
||||||
|
'django.middleware.security.SecurityMiddleware',
|
||||||
|
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
REST_FRAMEWORK = {
|
||||||
|
'EXCEPTION_HANDLER': 'common.exception.handle_exception.handle_exception',
|
||||||
|
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
|
||||||
|
'DEFAULT_AUTHENTICATION_CLASSES': ['common.auth.authenticate.AnonymousAuthentication']
|
||||||
|
}
|
||||||
|
STATICFILES_DIRS = [(os.path.join(PROJECT_DIR, 'ui', 'dist'))]
|
||||||
|
STATIC_ROOT = os.path.join(BASE_DIR.parent, 'static')
|
||||||
|
ROOT_URLCONF = 'maxkb.urls'
|
||||||
|
APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
|
||||||
|
|
||||||
|
TEMPLATES = [
|
||||||
|
{
|
||||||
|
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||||
|
'DIRS': ["apps/static/admin"],
|
||||||
|
'APP_DIRS': True,
|
||||||
|
'OPTIONS': {
|
||||||
|
'context_processors': [
|
||||||
|
'django.template.context_processors.debug',
|
||||||
|
'django.template.context_processors.request',
|
||||||
|
'django.contrib.auth.context_processors.auth',
|
||||||
|
'django.contrib.messages.context_processors.messages',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"NAME": "CHAT",
|
||||||
|
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||||
|
'DIRS': ["apps/static/chat"],
|
||||||
|
'APP_DIRS': True,
|
||||||
|
'OPTIONS': {
|
||||||
|
'context_processors': [
|
||||||
|
'django.template.context_processors.debug',
|
||||||
|
'django.template.context_processors.request',
|
||||||
|
'django.contrib.auth.context_processors.auth',
|
||||||
|
'django.contrib.messages.context_processors.messages',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"NAME": "DOC",
|
||||||
|
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||||
|
'DIRS': ["apps/static/drf_spectacular_sidecar"],
|
||||||
|
'APP_DIRS': True,
|
||||||
|
'OPTIONS': {
|
||||||
|
'context_processors': [
|
||||||
|
'django.template.context_processors.debug',
|
||||||
|
'django.template.context_processors.request',
|
||||||
|
'django.contrib.auth.context_processors.auth',
|
||||||
|
'django.contrib.messages.context_processors.messages',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
SPECTACULAR_SETTINGS = {
|
||||||
|
'TITLE': 'MaxKB API',
|
||||||
|
'DESCRIPTION': _('Intelligent customer service platform'),
|
||||||
|
'VERSION': 'v2',
|
||||||
|
'SERVE_INCLUDE_SCHEMA': False,
|
||||||
|
# OTHER SETTINGS
|
||||||
|
'SWAGGER_UI_DIST': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist', # shorthand to use the sidecar instead
|
||||||
|
'SWAGGER_UI_FAVICON_HREF': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist/favicon-32x32.png',
|
||||||
|
'REDOC_DIST': f'{CONFIG.get_admin_path()}/api-doc/redoc',
|
||||||
|
'SECURITY_DEFINITIONS': {
|
||||||
|
'Bearer': {
|
||||||
|
'type': 'apiKey',
|
||||||
|
'name': 'AUTHORIZATION',
|
||||||
|
'in': 'header',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
WSGI_APPLICATION = 'maxkb.wsgi.application'
|
||||||
|
|
||||||
|
# Database
|
||||||
|
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
|
||||||
|
|
||||||
|
DATABASES = {'default': CONFIG.get_db_setting()}
|
||||||
|
|
||||||
|
CACHES = CONFIG.get_cache_setting()
|
||||||
|
|
||||||
|
# Password validation
|
||||||
|
# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators
|
||||||
|
|
||||||
|
AUTH_PASSWORD_VALIDATORS = [
|
||||||
|
{
|
||||||
|
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Internationalization
|
||||||
|
# https://docs.djangoproject.com/en/4.2/topics/i18n/
|
||||||
|
|
||||||
|
LANGUAGE_CODE = CONFIG.get("LANGUAGE_CODE")
|
||||||
|
|
||||||
|
TIME_ZONE = CONFIG.get_time_zone()
|
||||||
|
|
||||||
|
USE_I18N = True
|
||||||
|
|
||||||
|
USE_TZ = True
|
||||||
|
|
||||||
|
# 文件上传配置
|
||||||
|
DATA_UPLOAD_MAX_NUMBER_FILES = 1000
|
||||||
|
|
||||||
|
# 支持的语言
|
||||||
|
LANGUAGES = [
|
||||||
|
('en', 'English'),
|
||||||
|
('zh', '中文简体'),
|
||||||
|
('zh-hant', '中文繁体')
|
||||||
|
]
|
||||||
|
# 翻译文件路径
|
||||||
|
LOCALE_PATHS = [
|
||||||
|
os.path.join(BASE_DIR.parent, 'locales')
|
||||||
|
]
|
||||||
|
|
||||||
|
# Static files (CSS, JavaScript, Images)
|
||||||
|
# https://docs.djangoproject.com/en/4.2/howto/static-files/
|
||||||
|
|
||||||
|
STATIC_URL = 'static/'
|
||||||
|
|
||||||
|
# Default primary key field type
|
||||||
|
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
|
||||||
|
|
||||||
|
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
||||||
|
|
||||||
|
edition = 'CE'
|
||||||
|
|
||||||
|
if os.environ.get('MAXKB_REDIS_SENTINEL_SENTINELS') is not None:
|
||||||
|
DJANGO_REDIS_CONNECTION_FACTORY = "django_redis.pool.SentinelConnectionFactory"
|
||||||
|
|
@ -1,22 +1,18 @@
|
||||||
|
# coding=utf-8
|
||||||
"""
|
"""
|
||||||
Django settings for maxkb project.
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
Generated by 'django-admin startproject' using Django 4.2.4.
|
@file: web.py
|
||||||
|
@date:2025/11/5 14:53
|
||||||
For more information on this file, see
|
@desc:
|
||||||
https://docs.djangoproject.com/en/4.2/topics/settings/
|
|
||||||
|
|
||||||
For the full list of settings and their values, see
|
|
||||||
https://docs.djangoproject.com/en/4.2/ref/settings/
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from ..const import CONFIG, PROJECT_DIR
|
from ...const import CONFIG, PROJECT_DIR
|
||||||
import os
|
import os
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
||||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||||
|
|
||||||
# Quick-start development settings - unsuitable for production
|
# Quick-start development settings - unsuitable for production
|
||||||
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
|
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB-xpack
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2025/11/5 14:45
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||||
|
from .model import *
|
||||||
|
else:
|
||||||
|
from .web import *
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""
|
||||||
|
URL configuration for maxkb project.
|
||||||
|
|
||||||
|
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||||
|
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||||
|
Examples:
|
||||||
|
Function views
|
||||||
|
1. Add an import: from my_app import views
|
||||||
|
2. Add a URL to urlpatterns: path('', views.home, name='home')
|
||||||
|
Class-based views
|
||||||
|
1. Add an import: from other_app.views import Home
|
||||||
|
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
|
||||||
|
Including another URLconf
|
||||||
|
1. Import the include() function: from django.urls import include, path
|
||||||
|
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||||
|
"""
|
||||||
|
|
||||||
|
from django.urls import path, include
|
||||||
|
|
||||||
|
from maxkb.const import CONFIG
|
||||||
|
|
||||||
|
admin_api_prefix = CONFIG.get_admin_path()[1:] + '/api/'
|
||||||
|
admin_ui_prefix = CONFIG.get_admin_path()
|
||||||
|
chat_api_prefix = CONFIG.get_chat_path()[1:] + '/api/'
|
||||||
|
chat_ui_prefix = CONFIG.get_chat_path()
|
||||||
|
urlpatterns = [
|
||||||
|
path(admin_api_prefix, include("local_model.urls")),
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2025/11/5 15:14
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||||
|
from .model import *
|
||||||
|
else:
|
||||||
|
from .web import *
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: model.py
|
||||||
|
@date:2025/11/5 15:14
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from django.core.wsgi import get_wsgi_application
|
||||||
|
|
||||||
|
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings')
|
||||||
|
|
||||||
|
application = get_wsgi_application()
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
|
# coding=utf-8
|
||||||
"""
|
"""
|
||||||
WSGI config for maxkb project.
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
It exposes the WSGI callable as a module-level variable named ``application``.
|
@file: web.py
|
||||||
|
@date:2025/11/5 15:14
|
||||||
For more information on this file, see
|
@desc:
|
||||||
https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from django.core.wsgi import get_wsgi_application
|
from django.core.wsgi import get_wsgi_application
|
||||||
|
|
@ -15,7 +15,7 @@ from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
from models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker
|
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
||||||
from django.utils.translation import gettext_lazy as _, gettext
|
from django.utils.translation import gettext_lazy as _, gettext
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -33,7 +33,7 @@ class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential)
|
model: LocalReranker = provider.get_model(model_type, model_name, model_credential)
|
||||||
model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
|
model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
|
||||||
|
|
@ -8,15 +8,16 @@
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
from common.utils.common import get_file_content
|
from common.utils.common import get_file_content
|
||||||
|
from maxkb.conf import PROJECT_DIR
|
||||||
from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
from models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
|
from models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
|
||||||
from models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
|
from models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
|
||||||
from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||||
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
||||||
from maxkb.conf import PROJECT_DIR
|
|
||||||
from django.utils.translation import gettext as _
|
|
||||||
|
|
||||||
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
|
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
|
||||||
LocalEmbeddingCredential(), LocalEmbedding)
|
LocalEmbeddingCredential(), LocalEmbedding)
|
||||||
|
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
"""
|
|
||||||
@project: MaxKB
|
|
||||||
@Author:虎
|
|
||||||
@file: embedding.py
|
|
||||||
@date:2024/7/11 14:06
|
|
||||||
@desc:
|
|
||||||
"""
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from langchain_core.embeddings import Embeddings
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from langchain_huggingface import HuggingFaceEmbeddings
|
|
||||||
|
|
||||||
from models_provider.base_model_provider import MaxKBBaseModel
|
|
||||||
from maxkb.const import CONFIG
|
|
||||||
|
|
||||||
|
|
||||||
class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
    """Embeddings client that forwards every request to the local-model HTTP service.

    The service address is read from the LOCAL_MODEL_PROTOCOL / LOCAL_MODEL_HOST /
    LOCAL_MODEL_PORT settings and the configured admin path; the target model is
    addressed by ``model_id``.
    """

    # Identifier of the model on the remote service, taken from the
    # ``model_id`` keyword passed to the constructor.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Returns None: instances are constructed by LocalEmbedding.new_instance.
        pass

    def embed_query(self, text: str) -> List[float]:
        """POST ``text`` to the ``embed_query`` endpoint and return the ``data`` field.

        Raises:
            Exception: with the service's ``message`` when the response ``code`` is not 200.
        """
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        prefix = CONFIG.get_admin_path()
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query',
            {'text': text})
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('message'))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """POST ``texts`` to the ``embed_documents`` endpoint and return the ``data`` field.

        Raises:
            Exception: with the service's ``message`` when the response ``code`` is not 200.
        """
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        prefix = CONFIG.get_admin_path()
        # The admin path is joined directly after the host, exactly as in
        # embed_query; an extra "/" here produced "host:port//<admin>/api/...".
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_documents',
            {'texts': texts})
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('message'))
|
|
||||||
|
|
||||||
|
|
||||||
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
    """HuggingFace embedding model plus the factory that picks local or remote execution."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the embedding object for ``model_name``.

        When ``use_local`` (default True) is falsy a WebLocalEmbedding is returned
        and the extra ``model_kwargs`` are forwarded to it; otherwise a
        LocalEmbedding is returned.
        """
        if not model_kwargs.get('use_local', True):
            return WebLocalEmbedding(model_name=model_name,
                                     cache_folder=model_credential.get('cache_folder'),
                                     model_kwargs={'device': model_credential.get('device')},
                                     encode_kwargs={'normalize_embeddings': True},
                                     **model_kwargs)
        return LocalEmbedding(model_name=model_name,
                              cache_folder=model_credential.get('cache_folder'),
                              model_kwargs={'device': model_credential.get('device')},
                              encode_kwargs={'normalize_embeddings': True})
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2025/11/5 15:24
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Select the implementation module by process role: SERVER_NAME == 'local_model'
# re-exports .model; any other value (default 'web') re-exports .web.
if os.environ.get('SERVER_NAME', 'web') != 'local_model':
    from .web import *
else:
    from .model import *
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: model.py
|
||||||
|
@date:2025/11/5 15:26
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
    """HuggingFaceEmbeddings subclass created from a MaxKB model credential."""

    @staticmethod
    def is_cache_model():
        """Always report True."""
        return True

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a LocalEmbedding for ``model_name``.

        ``cache_folder`` and ``device`` are read from ``model_credential``;
        embeddings are requested with ``normalize_embeddings=True``. The extra
        ``model_kwargs`` are not used.
        """
        cache_folder = model_credential.get('cache_folder')
        device_options = {'device': model_credential.get('device')}
        encode_options = {'normalize_embeddings': True}
        return LocalEmbedding(model_name=model_name, cache_folder=cache_folder,
                              model_kwargs=device_options,
                              encode_kwargs=encode_options)
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: web.py
|
||||||
|
@date:2025/11/5 15:24
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from anthropic import BaseModel
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
from maxkb.const import CONFIG
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
    """Embeddings client that forwards every request to the local-model HTTP service.

    The service address is read from the LOCAL_MODEL_PROTOCOL / LOCAL_MODEL_HOST /
    LOCAL_MODEL_PORT settings and the configured admin path; the target model is
    addressed by ``model_id``.
    """

    # Identifier of the model on the remote service, taken from the
    # ``model_id`` keyword passed to the constructor.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create the client; ``model_kwargs`` are forwarded to the constructor."""
        return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
                              model_kwargs={'device': model_credential.get('device')},
                              encode_kwargs={'normalize_embeddings': True},
                              **model_kwargs)

    def _api_url(self, action: str) -> str:
        """Return the service URL for ``action`` on this model."""
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        prefix = CONFIG.get_admin_path()
        # The admin path is joined directly after the host for every endpoint.
        # embed_documents previously inserted an extra "/" here, producing
        # "host:port//<admin>/api/..." unlike embed_query.
        return f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/{action}'

    def _post(self, action: str, payload: dict):
        """POST ``payload`` as form data and return the ``data`` field of the JSON reply.

        Raises:
            Exception: with the service's ``message`` when the response ``code`` is not 200.
        """
        res = requests.post(self._api_url(action), payload)
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('message'))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via the ``embed_query`` endpoint."""
        return self._post('embed_query', {'text': text})

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts via the ``embed_documents`` endpoint."""
        return self._post('embed_documents', {'texts': texts})
|
||||||
|
|
@ -1,102 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
"""
|
|
||||||
@project: MaxKB
|
|
||||||
@Author:虎
|
|
||||||
@file: reranker.py
|
|
||||||
@date:2024/9/2 16:42
|
|
||||||
@desc:
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Optional, Dict, Any, ClassVar
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
from langchain_core.callbacks import Callbacks
|
|
||||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
|
||||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
||||||
|
|
||||||
from models_provider.base_model_provider import MaxKBBaseModel
|
|
||||||
from maxkb.const import CONFIG
|
|
||||||
|
|
||||||
|
|
||||||
class LocalReranker(MaxKBBaseModel):
    """Factory front for the reranker: picks the in-process or the HTTP-backed implementation."""

    def __init__(self, model_name, top_n=3, cache_dir=None):
        super().__init__()
        self.top_n = top_n
        self.cache_dir = cache_dir
        self.model_name = model_name

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Return a LocalBaseReranker when ``use_local`` (default True) is truthy,
        otherwise a WebLocalBaseReranker that also receives the extra ``model_kwargs``.
        """
        cache_dir = model_credential.get('cache_dir')
        if not model_kwargs.get('use_local', True):
            return WebLocalBaseReranker(model_name=model_name, cache_dir=cache_dir,
                                        model_kwargs={'device': model_credential.get('device')},
                                        **model_kwargs)
        # NOTE(review): the device is passed nested under the 'model_kwargs'
        # keyword, while LocalBaseReranker.__init__ looks up 'device' at the top
        # level of its **model_kwargs -- confirm whether the configured device
        # is meant to take effect here.
        return LocalBaseReranker(model_name=model_name, cache_dir=cache_dir,
                                 model_kwargs={'device': model_credential.get('device', 'cpu')})
|
|
||||||
|
|
||||||
|
|
||||||
class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Document compressor that delegates reranking to the local-model HTTP service."""

    # Identifier of the model on the remote service, taken from the
    # ``model_id`` keyword passed to the constructor.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Returns None: instances are constructed by LocalReranker.new_instance.
        pass

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Send ``documents`` and ``query`` to the ``compress_documents`` endpoint.

        Returns an empty list for missing or empty input; otherwise the documents
        rebuilt from the ``data`` field of the JSON reply.

        Raises:
            Exception: with the service's ``message`` when the response ``code`` is not 200.
        """
        if documents is None or len(documents) == 0:
            return []
        prefix = CONFIG.get_admin_path()
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        url = f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents'
        payload = {
            'documents': [{'page_content': document.page_content, 'metadata': document.metadata}
                          for document in documents],
            'query': query,
        }
        res = requests.post(url, json=payload, headers={'Content-Type': 'application/json'})
        result = res.json()
        if result.get('code', 500) != 200:
            raise Exception(result.get('message'))
        return [Document(page_content=document.get('page_content'), metadata=document.get('metadata'))
                for document in result.get('data')]
|
|
||||||
|
|
||||||
|
|
||||||
class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """In-process reranker backed by a sequence-classification model and its tokenizer."""

    # Loaded AutoModelForSequenceClassification instance.
    client: Any = None
    # Loaded AutoTokenizer instance.
    tokenizer: Any = None
    # Model name passed to from_pretrained.
    model: Optional[str] = None
    # cache_dir passed to from_pretrained.
    cache_dir: Optional[str] = None
    # Extra keyword arguments given to the constructor; 'device' is read from here.
    model_kwargs: Any = {}

    def __init__(self, model_name, cache_dir=None, **model_kwargs):
        super().__init__()
        self.model = model_name
        self.cache_dir = cache_dir
        self.model_kwargs = model_kwargs
        self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
        # Move the model to the requested device (default 'cpu') and switch to eval mode.
        self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
        self.client.eval()

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a reranker, reading ``cache_dir`` from the credential."""
        return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Score every document against ``query`` and return them best-first.

        Each returned Document carries only ``relevance_score`` (the sigmoid of
        the model logit) in its metadata. Missing or empty input yields ``[]``.
        """
        if documents is None or len(documents) == 0:
            return []
        with torch.no_grad():
            pairs = [[query, document.page_content] for document in documents]
            inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            logits = self.client(**inputs, return_dict=True).logits.view(-1, ).float()
            scores = [torch.sigmoid(s).float().item() for s in logits]
        scored = [Document(page_content=document.page_content, metadata={'relevance_score': scores[index]})
                  for index, document in enumerate(documents)]
        return sorted(scored, key=lambda row: row.metadata.get('relevance_score'), reverse=True)
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py
|
||||||
|
@date:2025/11/5 15:30
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Select the implementation module by process role: SERVER_NAME == 'local_model'
# re-exports .model; any other value (default 'web') re-exports .web.
if os.environ.get('SERVER_NAME', 'web') != 'local_model':
    from .web import *
else:
    from .model import *
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: model.py
|
||||||
|
@date:2025/11/5 15:30
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Optional, Dict, Any
|
||||||
|
|
||||||
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.documents import Document, BaseDocumentCompressor
|
||||||
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LocalReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """In-process reranker backed by a sequence-classification model and its tokenizer."""

    # Loaded AutoModelForSequenceClassification instance.
    client: Any = None
    # Loaded AutoTokenizer instance.
    tokenizer: Any = None
    # Model name passed to from_pretrained.
    model: Optional[str] = None
    # cache_dir passed to from_pretrained.
    cache_dir: Optional[str] = None
    # Extra keyword arguments given to the constructor; 'device' is read from here.
    model_kwargs: Any = {}

    def __init__(self, model_name, cache_dir=None, **model_kwargs):
        super().__init__()
        self.model = model_name
        self.cache_dir = cache_dir
        self.model_kwargs = model_kwargs
        self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
        # Move the model to the requested device (default 'cpu') and switch to eval mode.
        self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
        self.client.eval()

    @staticmethod
    def is_cache_model():
        """Always report False."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a reranker, reading ``cache_dir`` from the credential.

        The extra ``model_kwargs`` are not forwarded to the constructor.
        """
        cache_dir = model_credential.get('cache_dir')
        return LocalReranker(model_name, cache_dir=cache_dir)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Score every document against ``query`` and return them best-first.

        Each returned Document carries only ``relevance_score`` (the sigmoid of
        the model logit) in its metadata. Missing or empty input yields ``[]``.
        """
        if documents is None or len(documents) == 0:
            return []
        # torch is imported here, at call time, rather than at module import.
        import torch
        with torch.no_grad():
            pairs = [[query, document.page_content] for document in documents]
            inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            logits = self.client(**inputs, return_dict=True).logits.view(-1, ).float()
            scores = [torch.sigmoid(s).float().item() for s in logits]
        scored = [Document(page_content=document.page_content, metadata={'relevance_score': scores[index]})
                  for index, document in enumerate(documents)]
        return sorted(scored, key=lambda row: row.metadata.get('relevance_score'), reverse=True)
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: web.py
|
||||||
|
@date:2025/11/5 15:30
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Optional, Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from anthropic import BaseModel
|
||||||
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.documents import Document, BaseDocumentCompressor
|
||||||
|
|
||||||
|
from maxkb.const import CONFIG
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor):
    """Document compressor that delegates reranking to the local-model HTTP service.

    The service address is read from the LOCAL_MODEL_PROTOCOL / LOCAL_MODEL_HOST /
    LOCAL_MODEL_PORT settings and the configured admin path; the target model is
    addressed by ``model_id``.
    """

    # Identifier of the model on the remote service, taken from the
    # ``model_id`` keyword passed to the constructor.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # A leftover debug print of model_id used to sit here; removed so that
        # constructing a reranker no longer writes to stdout.
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def is_cache_model():
        """Always report False."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create the client; all arguments are forwarded to the constructor as keywords."""
        return LocalReranker(model_type=model_type, model_name=model_name, model_credential=model_credential,
                             **model_kwargs)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Send ``documents`` and ``query`` to the ``compress_documents`` endpoint.

        Returns an empty list for missing or empty input; otherwise the documents
        rebuilt from the ``data`` field of the JSON reply.

        Raises:
            Exception: with the service's ``message`` when the response ``code`` is not 200.
        """
        if documents is None or len(documents) == 0:
            return []
        prefix = CONFIG.get_admin_path()
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents',
            json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
                                documents], 'query': query}, headers={'Content-Type': 'application/json'})
        result = res.json()
        if result.get('code', 500) == 200:
            return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
                    in result.get('data')]
        raise Exception(result.get('message'))
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from typing import Sequence, Optional, Any, Dict
|
from typing import Sequence, Optional, Dict
|
||||||
|
|
||||||
from langchain_community.embeddings import OllamaEmbeddings
|
from langchain_community.embeddings import OllamaEmbeddings
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from models_provider.base_model_provider import MaxKBBaseModel
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
||||||
top_n: Optional[int] = Field(3, description="Number of top documents to return")
|
top_n: Optional[int] = Field(3, description="Number of top documents to return")
|
||||||
|
|
@ -22,6 +22,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
||||||
|
|
||||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||||
Sequence[Document]:
|
Sequence[Document]:
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
"""Rank documents based on their similarity to the query.
|
"""Rank documents based on their similarity to the query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -37,7 +38,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
||||||
document_embeddings = self.embed_documents(documents)
|
document_embeddings = self.embed_documents(documents)
|
||||||
# 计算相似度
|
# 计算相似度
|
||||||
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
|
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
|
||||||
ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
|
ranked_docs = [(doc, _) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=doc, # 第一个值是文档内容
|
page_content=doc, # 第一个值是文档内容
|
||||||
|
|
@ -45,5 +46,3 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
||||||
)
|
)
|
||||||
for doc, score in ranked_docs
|
for doc, score in ranked_docs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
9
main.py
9
main.py
|
|
@ -13,7 +13,6 @@ APP_DIR = os.path.join(BASE_DIR, 'apps')
|
||||||
os.chdir(BASE_DIR)
|
os.chdir(BASE_DIR)
|
||||||
sys.path.insert(0, APP_DIR)
|
sys.path.insert(0, APP_DIR)
|
||||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "maxkb.settings")
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "maxkb.settings")
|
||||||
django.setup()
|
|
||||||
|
|
||||||
|
|
||||||
def collect_static():
|
def collect_static():
|
||||||
|
|
@ -74,7 +73,6 @@ def dev():
|
||||||
elif services.__contains__('celery'):
|
elif services.__contains__('celery'):
|
||||||
management.call_command('celery', 'celery')
|
management.call_command('celery', 'celery')
|
||||||
elif services.__contains__('local_model'):
|
elif services.__contains__('local_model'):
|
||||||
os.environ.setdefault('SERVER_NAME', 'local_model')
|
|
||||||
from maxkb.const import CONFIG
|
from maxkb.const import CONFIG
|
||||||
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||||
management.call_command('runserver', bind)
|
management.call_command('runserver', bind)
|
||||||
|
|
@ -108,6 +106,12 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('-f', '--force', nargs="?", const=True)
|
parser.add_argument('-f', '--force', nargs="?", const=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
action = args.action
|
action = args.action
|
||||||
|
services = args.services if isinstance(args.services, list) else args.services
|
||||||
|
if services.__contains__('web'):
|
||||||
|
os.environ.setdefault('SERVER_NAME', 'web')
|
||||||
|
elif services.__contains__('local_model'):
|
||||||
|
os.environ.setdefault('SERVER_NAME', 'local_model')
|
||||||
|
django.setup()
|
||||||
if action == "upgrade_db":
|
if action == "upgrade_db":
|
||||||
perform_db_migrate()
|
perform_db_migrate()
|
||||||
elif action == "collect_static":
|
elif action == "collect_static":
|
||||||
|
|
@ -120,4 +124,3 @@ if __name__ == '__main__':
|
||||||
collect_static()
|
collect_static()
|
||||||
perform_db_migrate()
|
perform_db_migrate()
|
||||||
start_services()
|
start_services()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue