2024-08-21 06:46:11 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: MaxKB
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: model_apply_serializers.py
|
|
|
|
|
|
@date:2024/8/20 20:39
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
2024-12-26 07:29:55 +00:00
|
|
|
|
from django.db import connection
|
2024-08-21 06:46:11 +00:00
|
|
|
|
from django.db.models import QuerySet
|
2024-09-05 03:28:21 +00:00
|
|
|
|
from langchain_core.documents import Document
|
2024-08-21 06:46:11 +00:00
|
|
|
|
from rest_framework import serializers
|
|
|
|
|
|
|
|
|
|
|
|
from common.config.embedding_config import ModelManage
|
|
|
|
|
|
from common.util.field_message import ErrMessage
|
|
|
|
|
|
from setting.models import Model
|
|
|
|
|
|
from setting.models_provider import get_model
|
2025-01-13 03:15:51 +00:00
|
|
|
|
from django.utils.translation import gettext_lazy as _
|
2024-08-21 06:46:11 +00:00
|
|
|
|
|
|
|
|
|
|
def get_embedding_model(model_id):
    """
    Return the model instance registered under ``model_id``.

    The ``Model`` row is read from the database, then ``ModelManage.get_model`` is called with
    ``model_id`` and a loader callback that builds the instance via
    ``get_model(model, use_local=True)``. Whether that loader actually runs is decided by
    ``ModelManage.get_model``.

    :param model_id: primary key of the ``setting.models.Model`` row.
    :return: the value returned by ``ModelManage.get_model`` for ``model_id``.
    :raises serializers.ValidationError: if the loader runs and no ``Model`` row matches ``model_id``.
    """
    try:
        model = QuerySet(Model).filter(id=model_id).first()
    finally:
        # Close the database connection manually; doing it in ``finally`` guarantees the
        # connection is released even when the query itself raises.
        connection.close()

    def _load(_id):
        # Fail with an explicit message rather than handing ``None`` to ``get_model``.
        if model is None:
            raise serializers.ValidationError(_('Model does not exist'))
        return get_model(model, use_local=True)

    return ModelManage.get_model(model_id, _load)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmbedDocuments(serializers.Serializer):
    """Payload validator for ``embed_documents``: a required list of non-empty texts."""
    texts = serializers.ListField(
        required=True,
        child=serializers.CharField(
            required=True,
            error_messages=ErrMessage.char(_('vector text')),
        ),
        error_messages=ErrMessage.list(_('vector text list')),
    )
|
2024-08-21 06:46:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmbedQuery(serializers.Serializer):
    """Payload validator for ``embed_query``: the single text to embed."""
    text = serializers.CharField(
        error_messages=ErrMessage.char(_('vector text')),
        required=True,
    )
|
2024-08-21 06:46:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
2024-09-05 03:28:21 +00:00
|
|
|
|
class CompressDocument(serializers.Serializer):
    """One document passed to ``compress_documents``: required text plus optional metadata."""
    page_content = serializers.CharField(
        error_messages=ErrMessage.char(_('text')),
        required=True,
    )
    metadata = serializers.DictField(
        error_messages=ErrMessage.dict(_('metadata')),
        required=False,
    )
|
2024-09-05 03:28:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CompressDocuments(serializers.Serializer):
    """Payload validator for ``compress_documents``: a list of documents and the query string."""
    documents = CompressDocument(many=True, required=True)
    query = serializers.CharField(
        error_messages=ErrMessage.char(_('query')),
        required=True,
    )
|
2024-09-05 03:28:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
2024-08-21 06:46:11 +00:00
|
|
|
|
class ModelApplySerializers(serializers.Serializer):
    """
    Apply the model identified by ``model_id`` to caller-supplied input.

    Each public method optionally validates ``model_id`` and its payload, resolves the model
    through ``get_embedding_model`` and delegates to the model method of the same name.
    """
    model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('model id')))

    def _resolve_model(self):
        """Return the model for this serializer's ``model_id``."""
        # DRF refuses ``self.data`` on a serializer built with ``data=`` until ``is_valid()``
        # has run, so read the raw input when the caller skipped validation.
        if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'):
            model_id = self.initial_data.get('model_id')
        else:
            model_id = self.data.get('model_id')
        return get_embedding_model(model_id)

    def embed_documents(self, instance, with_valid=True):
        """
        Embed a list of texts with the selected model.

        :param instance: mapping holding a ``texts`` list; either a ``QueryDict`` (read with
            ``getlist``) or a plain ``dict`` such as a parsed JSON body.
        :param with_valid: when true, validate ``model_id`` and ``instance`` first.
        :return: the result of the model's ``embed_documents``.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedDocuments(data=instance).is_valid(raise_exception=True)
        model = self._resolve_model()
        # ``QueryDict.get`` yields only the last value of a repeated key, so prefer ``getlist``;
        # a plain dict already stores the whole list under ``texts``.
        if hasattr(instance, 'getlist'):
            texts = instance.getlist('texts')
        else:
            texts = instance.get('texts')
        return model.embed_documents(texts)

    def embed_query(self, instance, with_valid=True):
        """
        Embed a single text with the selected model.

        :param instance: mapping holding a ``text`` string.
        :param with_valid: when true, validate ``model_id`` and ``instance`` first.
        :return: the result of the model's ``embed_query``.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedQuery(data=instance).is_valid(raise_exception=True)
        model = self._resolve_model()
        return model.embed_query(instance.get('text'))

    def compress_documents(self, instance, with_valid=True):
        """
        Run the selected model's ``compress_documents`` over the given documents.

        :param instance: mapping with ``documents`` (a list of ``{'page_content', 'metadata'}``
            mappings, ``metadata`` optional) and ``query``.
        :param with_valid: when true, validate ``model_id`` and ``instance`` first.
        :return: list of ``{'page_content': ..., 'metadata': ...}`` dicts for the documents the
            model returned.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            CompressDocuments(data=instance).is_valid(raise_exception=True)
        model = self._resolve_model()
        # ``metadata`` is optional in ``CompressDocument``; ``Document`` requires a dict, so
        # substitute an empty one instead of passing ``None``.
        documents = [Document(page_content=document.get('page_content'),
                              metadata=document.get('metadata') or {})
                     for document in instance.get('documents')]
        compressed = model.compress_documents(documents, instance.get('query'))
        return [{'page_content': d.page_content, 'metadata': d.metadata} for d in compressed]
|