UnisKB/apps/knowledge/vector/pg_vector.py

243 lines
10 KiB
Python
Raw Normal View History

# coding=utf-8
"""
@project: maxkb
@Author
@file pg_vector.py
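@date: 2023/10/19 15:28
@desc: Vector store backed by PostgreSQL through the Django Embedding model, with embedding, keyword and blend search modes.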
"""
import json
import os
from abc import ABC, abstractmethod
from typing import Dict, List
import uuid_utils.compat as uuid
from django.contrib.postgres.search import SearchVector
from django.db.models import QuerySet, Value
from langchain_core.embeddings import Embeddings
from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list
from common.utils.common import get_file_content
from common.utils.ts_vecto_util import to_ts_vector, to_query
from knowledge.models import Embedding, SearchMode, SourceType
from knowledge.vector.base_vector import BaseVectorStore
from maxkb.conf import PROJECT_DIR
class PGVector(BaseVectorStore):
    """Vector store that persists vectors through the Django ``Embedding`` model.

    Writes and deletes go through the ORM. Reads (``hit_test`` / ``query``) build a
    filtered ``Embedding`` queryset and hand it to the first entry of the module-level
    ``search_handle_list`` whose ``support()`` accepts the requested ``SearchMode``.
    """

    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        """Delete embeddings whose ``source_id`` is in *source_ids* and whose ``source_type`` matches.

        Does nothing when *source_ids* is empty or None.
        """
        if not source_ids:
            return
        QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()

    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        """Apply the field values in *instance* to every embedding whose ``source_id`` is in *source_ids*."""
        QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

    def vector_is_create(self) -> bool:
        # The storage is created by default when the project starts; it never needs creating here.
        return True

    def vector_create(self):
        # Nothing to create: the storage already exists at project start-up.
        return True

    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: Embeddings):
        """Embed *text* with the *embedding* model and insert a single ``Embedding`` row.

        :return: True once the row has been saved.
        """
        text_embedding = [float(x) for x in embedding.embed_query(text)]
        # Use a separate name so the ``embedding`` model argument is not shadowed by the row.
        record = Embedding(
            id=uuid.uuid7(),
            knowledge_id=knowledge_id,
            document_id=document_id,
            is_active=is_active,
            paragraph_id=paragraph_id,
            source_id=source_id,
            embedding=text_embedding,
            source_type=source_type,
            search_vector=to_ts_vector(text),
        )
        record.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """Embed every row's ``text`` in one model call and bulk-insert the resulting rows.

        Each dict in *text_list* supplies ``text``, ``document_id``, ``paragraph_id``,
        ``knowledge_id``, ``source_id``, ``source_type`` and optionally ``is_active``
        (default True). Nothing is written when ``is_the_task_interrupted()`` is truthy.

        :return: True.
        :raises ValueError: if the model returns a different number of vectors than texts.
        """
        if not text_list or is_the_task_interrupted():
            # Nothing to store, or the task is already cancelled: skip the embedding call entirely.
            return True
        texts = [row.get('text') for row in text_list]
        vectors = embedding.embed_documents(texts)
        if len(vectors) != len(texts):
            # Fail loudly instead of silently dropping or misaligning rows.
            raise ValueError(f"embedding model returned {len(vectors)} vectors for {len(texts)} texts")
        # NOTE(review): _save assigns the to_ts_vector() result directly to search_vector, while this
        # method wraps it in SearchVector(Value(...)); confirm both paths are meant to store the same tsvector.
        embedding_list = [
            Embedding(
                id=uuid.uuid7(),
                document_id=row.get('document_id'),
                paragraph_id=row.get('paragraph_id'),
                knowledge_id=row.get('knowledge_id'),
                is_active=row.get('is_active', True),
                source_id=row.get('source_id'),
                source_type=row.get('source_type'),
                embedding=[float(x) for x in vector],
                search_vector=SearchVector(Value(to_ts_vector(row['text']))),
            )
            for row, vector in zip(text_list, vectors)
        ]
        # Re-check: the task may have been interrupted while the embedding call was running.
        if not is_the_task_interrupted():
            QuerySet(Embedding).bulk_create(embedding_list)
        return True

    @staticmethod
    def _dispatch_search(query_set, query_text, query_embedding, top_number, similarity, search_mode):
        """Run the first handler in ``search_handle_list`` that supports *search_mode*; [] if none does."""
        for search_handle in search_handle_list:
            if search_handle.support(search_mode):
                return search_handle.handle(query_set, query_text, query_embedding, top_number, similarity,
                                            search_mode)
        return []

    def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        """Search active embeddings of the given knowledge bases for *query_text*.

        The query is embedded with *embedding*. Documents in *exclude_document_id_list*
        are excluded. *top_number*, *similarity* and *search_mode* are forwarded to the
        matching search handler.

        :return: the handler's result, or [] when *knowledge_id_list* is empty or no
                 handler supports *search_mode*.
        """
        if not knowledge_id_list:
            return []
        embedding_query = embedding.embed_query(query_text)
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=True)
        if exclude_document_id_list:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        return self._dispatch_search(query_set, query_text, embedding_query, top_number, similarity, search_mode)

    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              document_id_list: list[str],
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        """Search embeddings with a precomputed *query_embedding*.

        Rows are restricted to *knowledge_id_list* and the given *is_active* flag. They are
        optionally narrowed to *document_id_list* and stripped of excluded documents and paragraphs.

        :return: the handler's result, or [] when *knowledge_id_list* is empty or no
                 handler supports *search_mode*.
        """
        if not knowledge_id_list:
            return []
        query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
        if document_id_list:
            query_set = query_set.filter(document_id__in=document_id_list)
        if exclude_document_id_list:
            query_set = query_set.exclude(document_id__in=exclude_document_id_list)
        if exclude_paragraph_list:
            query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
        return self._dispatch_search(query_set, query_text, query_embedding, top_n, similarity, search_mode)

    def update_by_source_id(self, source_id: str, instance: Dict):
        """Apply the field values in *instance* to embeddings with this ``source_id``."""
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Apply the field values in *instance* to embeddings of this paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def update_by_paragraph_ids(self, paragraph_id: List[str], instance: Dict):
        """Apply the field values in *instance* to embeddings of every paragraph id in *paragraph_id*.

        The parameter keeps its historical singular name for keyword-argument callers,
        but it is a collection of ids (it is used with ``__in``).
        """
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance)

    def delete_by_knowledge_id(self, knowledge_id: str):
        """Delete every embedding belonging to one knowledge base."""
        QuerySet(Embedding).filter(knowledge_id=knowledge_id).delete()

    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        """Delete every embedding belonging to any of the given knowledge bases."""
        QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list).delete()

    def delete_by_document_id(self, document_id: str):
        """Delete every embedding of one document; returns True."""
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_document_id_list(self, document_id_list: List[str]):
        """Delete embeddings of the listed documents.

        :return: True when the list is empty or None, otherwise the value of ``QuerySet.delete()``.
        """
        if not document_id_list:
            return True
        return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

    def delete_by_source_id(self, source_id: str, source_type: str):
        """Delete embeddings with this ``source_id`` and ``source_type``; returns True."""
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        """Delete every embedding of one paragraph."""
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        """Delete every embedding of the listed paragraphs."""
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()
class ISearch(ABC):
    """Strategy interface: one implementation per search mode over an ``Embedding`` queryset."""

    @abstractmethod
    def support(self, search_mode: SearchMode):
        """Report whether this strategy can serve *search_mode*."""

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        """Execute the search over *query_set* and return its result."""
class EmbeddingSearch(ISearch):
    """Vector-similarity search backed by ``apps/knowledge/sql/embedding_search.sql``."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Run the embedding SQL over *query_set*.

        SQL parameters, in order: the vector length, the JSON-encoded vector, the
        queryset's own parameters, *similarity*, then *top_number*.
        *query_text* is not used by this strategy.
        """
        sql_path = os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'embedding_search.sql')
        sql_template = get_file_content(sql_path)
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=sql_template,
                                                           with_table_name=True)
        sql_args = [len(query_embedding), json.dumps(query_embedding)]
        sql_args.extend(exec_params)
        sql_args.extend((similarity, top_number))
        return select_list(exec_sql, sql_args)

    def support(self, search_mode: SearchMode):
        """True only for ``SearchMode.embedding``."""
        return search_mode.value == SearchMode.embedding.value
class KeywordsSearch(ISearch):
    """Full-text keyword search backed by ``apps/knowledge/sql/keywords_search.sql``."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Run the keyword SQL over *query_set*.

        SQL parameters, in order: ``to_query(query_text)``, the queryset's own
        parameters, *similarity*, then *top_number*. *query_embedding* is accepted
        for interface compatibility but is not used.
        """
        template = get_file_content(os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'keywords_search.sql'))
        sql, query_params = generate_sql_by_query_dict({'keywords_query': query_set},
                                                       select_string=template,
                                                       with_table_name=True)
        return select_list(sql, [to_query(query_text), *query_params, similarity, top_number])

    def support(self, search_mode: SearchMode):
        """True only for ``SearchMode.keywords``."""
        return search_mode.value == SearchMode.keywords.value
class BlendSearch(ISearch):
    """Hybrid vector + keyword search backed by ``apps/knowledge/sql/blend_search.sql``."""

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        """Run the blended SQL over *query_set*.

        SQL parameters, in order: the vector length, the JSON-encoded vector,
        ``to_query(query_text)``, the queryset's own parameters, *similarity*,
        then *top_number*.
        """
        sql_file = os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'blend_search.sql')
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'embedding_query': query_set},
            select_string=get_file_content(sql_file),
            with_table_name=True,
        )
        vector_args = [len(query_embedding), json.dumps(query_embedding)]
        text_args = [to_query(query_text)]
        tail_args = [similarity, top_number]
        return select_list(exec_sql, vector_args + text_args + list(exec_params) + tail_args)

    def support(self, search_mode: SearchMode):
        """True only for ``SearchMode.blend``."""
        return search_mode.value == SearchMode.blend.value
# Ordered registry of search strategies: PGVector uses the first entry whose
# support() accepts the requested SearchMode.
search_handle_list = [
    EmbeddingSearch(),
    KeywordsSearch(),
    BlendSearch(),
]