2025-04-17 10:01:33 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: MaxKB
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: embedding.py
|
|
|
|
|
|
@date:2024/10/17 15:29
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import base64
|
|
|
|
|
|
import json
|
|
|
|
|
|
from typing import Dict, Optional
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from langchain_community.embeddings import SparkLLMTextEmbeddings
|
|
|
|
|
|
from numpy import ndarray
|
|
|
|
|
|
|
|
|
|
|
|
from models_provider.base_model_provider import MaxKBBaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class XFEmbedding(MaxKBBaseModel, SparkLLMTextEmbeddings):
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
|
|
|
|
|
return XFEmbedding(
|
2025-06-27 08:34:22 +00:00
|
|
|
|
base_url=model_credential.get('base_url'),
|
2025-04-17 10:01:33 +00:00
|
|
|
|
spark_app_id=model_credential.get('spark_app_id'),
|
|
|
|
|
|
spark_api_key=model_credential.get('spark_api_key'),
|
|
|
|
|
|
spark_api_secret=model_credential.get('spark_api_secret')
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def _parser_message(
|
|
|
|
|
|
message: str,
|
|
|
|
|
|
) -> Optional[ndarray]:
|
|
|
|
|
|
data = json.loads(message)
|
|
|
|
|
|
code = data["header"]["code"]
|
|
|
|
|
|
if code != 0:
|
|
|
|
|
|
# 这里是讯飞的QPS限制会报错,所以不建议用讯飞的向量模型
|
|
|
|
|
|
raise Exception(f"Request error: {code}, {data}")
|
|
|
|
|
|
else:
|
|
|
|
|
|
text_base = data["payload"]["feature"]["text"]
|
|
|
|
|
|
text_data = base64.b64decode(text_base)
|
|
|
|
|
|
dt = np.dtype(np.float32)
|
|
|
|
|
|
dt = dt.newbyteorder("<")
|
|
|
|
|
|
text = np.frombuffer(text_data, dtype=dt)
|
|
|
|
|
|
if len(text) > 2560:
|
|
|
|
|
|
array = text[:2560]
|
|
|
|
|
|
else:
|
|
|
|
|
|
array = text
|
|
|
|
|
|
return array
|