UnisKB/apps/models_provider/impl/local_model_provider/model/embedding.py

65 lines
2.6 KiB
Python
Raw Normal View History

2025-04-17 10:01:33 +00:00
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
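@date: 2024/7/11 14:06
@desc: Embedding wrappers for locally deployed models: an in-process HuggingFace
       model (LocalEmbedding) and an HTTP client for the local model service (WebLocalEmbedding).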
"""
from typing import Dict, List
import requests
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from models_provider.base_model_provider import MaxKBBaseModel
from maxkb.const import CONFIG
class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
    """Embeddings client that delegates to the local model service over HTTP.

    Each call is POSTed (form-encoded) to
    ``{LOCAL_MODEL_PROTOCOL}://{LOCAL_MODEL_HOST}:{LOCAL_MODEL_PORT}{admin_path}/api/model/{model_id}/<action>``.
    When the JSON reply has ``code == 200``, its ``data`` field is returned.
    """

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Returns None. In this file, instances are built by
        # LocalEmbedding.new_instance (when use_local is falsy) instead.
        pass

    # Identifier of the model on the local model service; interpolated into the request path.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    def _post(self, action: str, payload: Dict[str, object]):
        """POST ``payload`` to the local model service endpoint for ``action``.

        :param action: final path segment, e.g. ``'embed_query'``.
        :param payload: form fields sent as the request body.
        :return: the ``data`` field of the JSON reply when its ``code`` is 200.
        :raises Exception: carrying the reply's ``message`` for any other code.
        """
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        prefix = CONFIG.get_admin_path()
        # The admin path is appended directly after host:port with no extra '/'.
        # Both embed methods share this single URL builder; previously embed_documents
        # inserted an extra '/' and produced a different ('//'-prefixed) path.
        # Assumes get_admin_path() returns a value with its own leading slash — TODO confirm.
        url = f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/{action}'
        # Sent as form data (requests' ``data=``), not JSON, as the service expected before.
        res = requests.post(url, data=payload)
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('message'))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via the local model service."""
        return self._post('embed_query', {'text': text})

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents via the local model service."""
        return self._post('embed_documents', {'texts': texts})
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
    """HuggingFace embedding model for locally deployed models.

    ``new_instance`` loads the model in-process unless ``use_local`` is passed
    with a falsy value. In that case it returns a ``WebLocalEmbedding``, which
    calls the local model service over HTTP instead.
    """

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an embedding model for ``model_name``.

        :param model_type: not used by this factory.
        :param model_name: name/path of the HuggingFace model.
        :param model_credential: mapping supplying ``cache_folder`` and ``device``.
        :param model_kwargs: ``use_local`` (default True) picks the backend. For
            the HTTP backend, every entry is forwarded to ``WebLocalEmbedding``.
        :return: a ``LocalEmbedding`` or a ``WebLocalEmbedding``.
        """
        run_in_process = model_kwargs.get('use_local', True)
        shared_kwargs = dict(
            model_name=model_name,
            cache_folder=model_credential.get('cache_folder'),
            model_kwargs={'device': model_credential.get('device')},
            # Embeddings are L2-normalised on both backends.
            encode_kwargs={'normalize_embeddings': True},
        )
        if not run_in_process:
            return WebLocalEmbedding(**shared_kwargs, **model_kwargs)
        return LocalEmbedding(**shared_kwargs)