UnisKB/apps/models_provider/impl/xf_model_provider/model/llm.py

79 lines
2.8 KiB
Python
Raw Normal View History

2025-04-17 10:01:33 +00:00
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/04/19 15:55
@desc:
"""
from typing import List, Optional, Any, Iterator, Dict
from langchain_community.chat_models.sparkllm import \
ChatSparkLLM, convert_message_to_dict, _convert_delta_to_message_chunk
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from models_provider.base_model_provider import MaxKBBaseModel
class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
spark_llm_domain=model_name,
streaming=model_kwargs.get('streaming', False),
**optional_params
)
usage_metadata: dict = {}
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.usage_metadata
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return self.usage_metadata.get('prompt_tokens', 0)
def get_num_tokens(self, text: str) -> int:
return self.usage_metadata.get('completion_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
default_chunk_class = AIMessageChunk
self.client.arun(
[convert_message_to_dict(m) for m in messages],
self.spark_user_id,
self.model_kwargs,
True,
)
for content in self.client.subscribe(timeout=self.request_timeout):
if "data" in content:
delta = content["data"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
cg_chunk = ChatGenerationChunk(message=chunk)
elif "usage" in content:
generation_info = content["usage"]
self.usage_metadata = generation_info
continue
else:
continue
if cg_chunk is not None:
if run_manager:
run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk)
yield cg_chunk