UnisKB/apps/models_provider/impl/wenxin_model_provider/model/llm.py

105 lines
4.1 KiB
Python
Raw Normal View History

2025-04-17 10:01:33 +00:00
# coding=utf-8
"""
@project: maxkb
@Author
@file: llm.py
@date: 2023/11/10 17:45
@desc:
"""
from typing import List, Dict, Optional, Any, Iterator
from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message, QianfanChatEndpoint
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import ChatGenerationChunk
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
2025-04-17 10:01:33 +00:00
class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
    """Qianfan chat model backed by ``QianfanChatEndpoint`` (selected for ``api_version`` "v1").

    Overrides ``_stream`` so the token usage reported by the service on a
    streamed response is remembered in ``usage_metadata``.  The
    ``get_num_tokens*`` overrides then return those reported counts instead
    of tokenizing text locally.
    """

    # Token usage (the ``usage`` dict of a response body) captured by the most
    # recent ``_stream`` call that reported it; empty until then.
    # NOTE(review): assumes the base class copies this mutable default per
    # instance (pydantic field semantics) — confirm.
    usage_metadata: dict = {}

    @staticmethod
    def is_cache_model():
        """Always return False (this model is reported as not cacheable)."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance from MaxKB model settings.

        :param model_type: unused here; kept for the common factory signature.
        :param model_name: Qianfan model name passed as ``model``.
        :param model_credential: dict providing ``api_key`` (``qianfan_ak``)
            and ``secret_key`` (``qianfan_sk``).
        :param model_kwargs: ``streaming`` (default False) plus optional
            parameters. The optional parameters are filtered by
            ``MaxKBBaseModel.filter_optional_params`` and stored as
            ``init_kwargs``, which ``_stream`` merges into every request.
        :return: a configured ``QianfanChatModelQianfan``.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModelQianfan(model=model_name,
                                       qianfan_ak=model_credential.get('api_key'),
                                       qianfan_sk=model_credential.get('secret_key'),
                                       streaming=model_kwargs.get('streaming', False),
                                       init_kwargs=optional_params)

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage dict captured from the last streamed response."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the prompt-token count reported by the last stream (0 if unknown).

        ``messages`` is ignored; the count comes from ``usage_metadata``.
        """
        # Tolerate a missing/None usage record instead of raising AttributeError.
        return (self.usage_metadata or {}).get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the completion-token count reported by the last stream (0 if unknown).

        ``text`` is ignored; the count comes from ``usage_metadata``.
        """
        return (self.usage_metadata or {}).get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat chunks from Qianfan and record the reported token usage.

        :param messages: conversation to send.
        :param stop: optional stop sequences, forwarded as ``params["stop"]``.
        :param run_manager: optional callback manager notified of each token.
        :param kwargs: per-call request options; these override ``init_kwargs``.
        :return: iterator of ``ChatGenerationChunk``; empty/falsy responses are skipped.
        """
        # Call-time kwargs win over the defaults stored at construction.
        kwargs = {**self.init_kwargs, **kwargs}
        params = self._convert_prompt_msg_params(messages, **kwargs)
        params["stop"] = stop
        params["stream"] = True
        for res in self.client.do(**params):
            if not res:
                continue
            msg = _convert_dict_to_message(res)
            additional_kwargs = msg.additional_kwargs.get("function_call", {})
            # Guard against a response without a body rather than crashing on None.get.
            body = res.get("body") or {}
            if msg.content == "" or body.get("is_end"):
                token_usage = body.get("usage")
                # Only overwrite when usage is actually reported; previously a
                # missing "usage" stored None and broke get_num_tokens*.
                if token_usage:
                    self.usage_metadata = token_usage
            chunk = ChatGenerationChunk(
                text=res["result"],
                message=AIMessageChunk(  # type: ignore[call-arg]
                    content=msg.content,
                    role="assistant",
                    additional_kwargs=additional_kwargs,
                ),
                generation_info=msg.additional_kwargs,
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
    """Qianfan chat model reached through ``BaseChatOpenAI`` (selected for ``api_version`` "v2")."""

    @staticmethod
    def is_cache_model():
        """Always return False (this model is reported as not cacheable)."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance from MaxKB model settings.

        :param model_type: unused here; kept for the common factory signature.
        :param model_name: model name passed as ``model``.
        :param model_credential: dict providing ``api_base`` and ``api_key``.
        :param model_kwargs: optional parameters. They are filtered by
            ``MaxKBBaseModel.filter_optional_params`` and sent as ``extra_body``.
        :return: a configured ``QianfanChatModelOpenai``.
        """
        extra_body = MaxKBBaseModel.filter_optional_params(model_kwargs)
        base_url = model_credential.get('api_base')
        key = model_credential.get('api_key')
        return QianfanChatModelOpenai(
            model=model_name,
            openai_api_base=base_url,
            openai_api_key=key,
            extra_body=extra_body,
        )
class QianfanChatModel(MaxKBBaseModel):
    """Factory that selects the Qianfan chat implementation by credential ``api_version``."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the chat model matching ``model_credential['api_version']``.

        :param model_type: forwarded unchanged to the selected implementation.
        :param model_name: forwarded unchanged to the selected implementation.
        :param model_credential: credential dict. ``api_version`` chooses the
            backend and defaults to ``"v1"`` when missing or empty.
        :param model_kwargs: forwarded unchanged to the selected implementation.
        :return: ``QianfanChatModelQianfan`` for "v1", ``QianfanChatModelOpenai`` for "v2".
        :raises ValueError: if ``api_version`` is any other value.
        """
        # ``or`` (rather than a .get default) also covers an explicit None/"" value.
        api_version = model_credential.get('api_version') or 'v1'
        if api_version == "v1":
            return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
        if api_version == "v2":
            return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
        # Previously this fell through and returned None, which only failed later and obscurely.
        raise ValueError(f"Unsupported Qianfan api_version: {api_version!r}")