feat: add function to retrieve default parameters for embedding models
--bug=1063177 --user=刘瑞斌 【知识库】-知识库使用的模型更换维度参数值并重新向量化后,命中测试、检索报错 https://www.tapd.cn/62980211/s/1792117v3.2
parent
2de6bd2018
commit
ed19db07d1
|
|
@ -112,6 +112,21 @@ class ProblemParagraphManage:
|
|||
], problem_paragraph_mapping_list
|
||||
return result
|
||||
|
||||
def get_embedding_model_default_params(model):
|
||||
def convert_to_int(value):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
|
||||
return {
|
||||
p.get('field'): convert_to_int(p.get('default_value'))
|
||||
for p in model.model_params_form
|
||||
if p.get('default_value') is not None
|
||||
}
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
||||
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
|
||||
|
|
@ -119,17 +134,29 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
|||
raise Exception(_('The knowledge base is inconsistent with the vector model'))
|
||||
if len(knowledge_list) == 0:
|
||||
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
|
||||
return ModelManage.get_model(str(knowledge_list[0].embedding_model_id),
|
||||
lambda _id: get_model(knowledge_list[0].embedding_model))
|
||||
|
||||
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
|
||||
|
||||
return ModelManage.get_model(
|
||||
str(knowledge_list[0].embedding_model_id),
|
||||
lambda _id: get_model(knowledge_list[0].embedding_model, **{**default_params})
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge_id(knowledge_id: str):
|
||||
knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
|
||||
return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
|
||||
|
||||
default_params = get_embedding_model_default_params(knowledge.embedding_model)
|
||||
|
||||
return ModelManage.get_model(str(knowledge.embedding_model_id),
|
||||
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge(knowledge):
|
||||
return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
|
||||
default_params = get_embedding_model_default_params(knowledge.embedding_model)
|
||||
|
||||
return ModelManage.get_model(str(knowledge.embedding_model_id),
|
||||
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
|
||||
|
||||
|
||||
def get_embedding_model_id_by_knowledge_id(knowledge_id):
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK
|
|||
UpdateEmbeddingDocumentIdArgs
|
||||
from common.utils.logger import maxkb_logger
|
||||
from knowledge.models import Document, TaskType, State
|
||||
from knowledge.serializers.common import drop_knowledge_index
|
||||
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model
|
||||
from ops import celery_app
|
||||
|
|
@ -26,21 +26,9 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
|
|||
try:
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
|
||||
def convert_to_int(value):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
default_params = get_embedding_model_default_params(model)
|
||||
|
||||
s = {
|
||||
p.get('field'): convert_to_int(p.get('default_value'))
|
||||
for p in model.model_params_form
|
||||
if p.get('default_value') is not None
|
||||
}
|
||||
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s}))
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
|
||||
except Exception as e:
|
||||
exception_handler(e)
|
||||
raise e
|
||||
|
|
|
|||
Loading…
Reference in New Issue