feat: Local model validation using local_madel (#4330)
parent
3f6453eb3a
commit
f457588cd5
|
|
@ -74,7 +74,6 @@ class ModelManage:
|
||||||
|
|
||||||
|
|
||||||
def get_local_model(model, **kwargs):
|
def get_local_model(model, **kwargs):
|
||||||
# system_setting = QuerySet(SystemSetting).filter(type=1).first()
|
|
||||||
return LocalModelProvider().get_model(model.model_type, model.model_name,
|
return LocalModelProvider().get_model(model.model_type, model.model_name,
|
||||||
json.loads(
|
json.loads(
|
||||||
rsa_long_decrypt(model.credential)),
|
rsa_long_decrypt(model.credential)),
|
||||||
|
|
@ -111,6 +110,21 @@ class CompressDocuments(serializers.Serializer):
|
||||||
query = serializers.CharField(required=True, label=_('query'))
|
query = serializers.CharField(required=True, label=_('query'))
|
||||||
|
|
||||||
|
|
||||||
|
class ValidateModelSerializers(serializers.Serializer):
|
||||||
|
model_name = serializers.CharField(required=True, label=_('model_name'))
|
||||||
|
|
||||||
|
model_type = serializers.CharField(required=True, label=_('model_type'))
|
||||||
|
|
||||||
|
model_credential = serializers.DictField(required=True, label="credential")
|
||||||
|
|
||||||
|
def validate_model(self, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'),
|
||||||
|
self.data.get('model_credential'), model_params={},
|
||||||
|
raise_exception=True)
|
||||||
|
|
||||||
|
|
||||||
class ModelApplySerializers(serializers.Serializer):
|
class ModelApplySerializers(serializers.Serializer):
|
||||||
model_id = serializers.UUIDField(required=True, label=_('model id'))
|
model_id = serializers.UUIDField(required=True, label=_('model id'))
|
||||||
|
|
||||||
|
|
@ -138,3 +152,9 @@ class ModelApplySerializers(serializers.Serializer):
|
||||||
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
|
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
|
||||||
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
|
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
|
||||||
instance.get('documents')], instance.get('query'))]
|
instance.get('documents')], instance.get('query'))]
|
||||||
|
|
||||||
|
def unload(self, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
ModelManage.delete_key(self.data.get('model_id'))
|
||||||
|
return True
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,9 @@ from . import views
|
||||||
app_name = "local_model"
|
app_name = "local_model"
|
||||||
# @formatter:off
|
# @formatter:off
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
|
path('model/validate', views.LocalModelApply.Validate.as_view()),
|
||||||
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
|
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
|
||||||
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
|
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
|
||||||
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
|
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
|
||||||
|
path('model/<str:model_id>/unload', views.LocalModelApply.Unload.as_view()),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from urllib.request import Request
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
from common.result import result
|
from common.result import result
|
||||||
from local_model.serializers.model_apply_serializers import ModelApplySerializers
|
from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers
|
||||||
|
|
||||||
|
|
||||||
class LocalModelApply(APIView):
|
class LocalModelApply(APIView):
|
||||||
|
|
@ -32,3 +32,12 @@ class LocalModelApply(APIView):
|
||||||
def post(self, request: Request, model_id):
|
def post(self, request: Request, model_id):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
||||||
|
|
||||||
|
class Unload(APIView):
|
||||||
|
def post(self, request: Request, model_id):
|
||||||
|
return result.success(
|
||||||
|
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
||||||
|
|
||||||
|
class Validate(APIView):
|
||||||
|
def post(self, request: Request):
|
||||||
|
return result.success(ValidateModelSerializers(data=request.data).validate_model())
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2025/11/7 14:02
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||||
|
from .model import *
|
||||||
|
else:
|
||||||
|
from .web import *
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
"""
|
"""
|
||||||
@project: MaxKB
|
@project: MaxKB
|
||||||
@Author:虎
|
@Author:虎虎
|
||||||
@file: embedding.py
|
@file: model.py.py
|
||||||
@date:2024/7/11 11:06
|
@date:2025/11/7 14:02
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: web.py
|
||||||
|
@date:2025/11/7 14:03
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from maxkb.const import CONFIG
|
||||||
|
from models_provider.base_model_provider import BaseModelCredential
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
|
raise_exception=False):
|
||||||
|
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||||
|
prefix = CONFIG.get_admin_path()
|
||||||
|
res = requests.post(
|
||||||
|
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
|
||||||
|
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
|
||||||
|
result = res.json()
|
||||||
|
if result.get('code', 500) == 200:
|
||||||
|
return result.get('data')
|
||||||
|
raise Exception(result.get('message'))
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return model
|
||||||
|
|
||||||
|
cache_folder = forms.TextInputField(_('Model catalog'), required=True)
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2025/11/7 14:22
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||||
|
from .model import *
|
||||||
|
else:
|
||||||
|
from .web import *
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
"""
|
"""
|
||||||
@project: MaxKB
|
@project: MaxKB
|
||||||
@Author:虎
|
@Author:虎虎
|
||||||
@file: reranker.py
|
@file: model.py
|
||||||
@date:2024/9/3 14:33
|
@date:2025/11/7 14:23
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎虎
|
||||||
|
@file: web.py
|
||||||
|
@date:2025/11/7 14:23
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from maxkb.const import CONFIG
|
||||||
|
from models_provider.base_model_provider import BaseModelCredential
|
||||||
|
|
||||||
|
|
||||||
|
class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
|
raise_exception=False):
|
||||||
|
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||||
|
prefix = CONFIG.get_admin_path()
|
||||||
|
res = requests.post(
|
||||||
|
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
|
||||||
|
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
|
||||||
|
result = res.json()
|
||||||
|
if result.get('code', 500) == 200:
|
||||||
|
return result.get('data')
|
||||||
|
raise Exception(result.get('message'))
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return model
|
||||||
|
|
||||||
|
cache_folder = forms.TextInputField(_('Model catalog'), required=True)
|
||||||
Loading…
Reference in New Issue