108 lines
3.8 KiB
Python
108 lines
3.8 KiB
Python
|
|
# coding=utf-8
|
|||
|
|
"""
|
|||
|
|
@project: maxkb
|
|||
|
|
@Author:虎
|
|||
|
|
@file: llm.py
|
|||
|
|
@date:2024/4/28 11:42
|
|||
|
|
@desc:
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
from collections.abc import Iterator
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
|
|||
|
|
from langchain_community.chat_models import ChatZhipuAI
|
|||
|
|
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
|
|||
|
|
_convert_delta_to_message_chunk
|
|||
|
|
from langchain_core.callbacks import (
|
|||
|
|
CallbackManagerForLLMRun,
|
|||
|
|
)
|
|||
|
|
from langchain_core.messages import (
|
|||
|
|
AIMessageChunk,
|
|||
|
|
BaseMessage
|
|||
|
|
)
|
|||
|
|
from langchain_core.outputs import ChatGenerationChunk
|
|||
|
|
|
|||
|
|
from models_provider.base_model_provider import MaxKBBaseModel
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
    """MaxKB chat model backed by ``ChatZhipuAI``.

    Overrides streaming so that the extra ``optional_params`` are sent with every
    request and the token usage reported by the server is remembered on
    ``usage_metadata`` for the ``get_num_tokens*`` / ``get_last_generation_info``
    accessors.
    """

    # Extra request parameters (filtered by MaxKBBaseModel.filter_optional_params)
    # that are merged into every streaming payload in _stream.
    optional_params: dict

    @staticmethod
    def is_cache_model():
        """Return False: instances of this model are not meant to be cached."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a ``ZhipuChatModel`` from MaxKB credentials and model kwargs.

        Args:
            model_type: Unused here; part of the MaxKB factory signature.
            model_name: Zhipu model identifier passed as ``model``.
            model_credential: Mapping holding at least ``'api_key'``.
            **model_kwargs: May contain ``streaming`` plus optional
                generation parameters, which are filtered by
                ``MaxKBBaseModel.filter_optional_params``.

        Returns:
            A configured ``ZhipuChatModel`` instance.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return ZhipuChatModel(
            api_key=model_credential.get('api_key'),
            model=model_name,
            streaming=model_kwargs.get('streaming', False),
            optional_params=optional_params,
            **optional_params,
        )

    # Usage dict reported by the server for the most recent streamed call
    # (e.g. 'prompt_tokens' / 'completion_tokens'); replaced by _stream.
    # NOTE(review): presumably declared as a field so the pydantic-based parent
    # accepts assignment to it — confirm before changing to a plain attribute.
    usage_metadata: dict = {}

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage dict captured from the most recent streamed call."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the prompt token count reported by the last streamed call.

        ``messages`` is ignored: the value comes from server-reported usage,
        not from tokenizing the input. Returns 0 when no usage was reported.
        """
        return (self.usage_metadata or {}).get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the completion token count reported by the last streamed call.

        ``text`` is ignored: the value comes from server-reported usage.
        Returns 0 when no usage was reported.
        """
        return (self.usage_metadata or {}).get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat response from Zhipu's SSE endpoint chunk by chunk.

        Any ``usage`` object sent by the server is stored on
        ``self.usage_metadata`` and attached as the chunk's ``generation_info``.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences.
            run_manager: Callback manager notified of every new token.
            **kwargs: Extra request parameters merged into the payload.

        Yields:
            One ``ChatGenerationChunk`` per streamed delta.

        Raises:
            ValueError: If the API key or the API base URL is not configured.
        """
        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        if self.zhipuai_api_base is None:
            raise ValueError("Did not find zhipu_api_base.")

        # Drop usage from any previous call so stale counts are never reported
        # when this stream carries no usage information.
        self.usage_metadata = {}

        message_dicts, params = self._create_message_dicts(messages, stop)
        # Later keys win: call kwargs override defaults, optional_params override
        # both, and messages/stream are always forced.
        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }

        default_chunk_class = AIMessageChunk
        import httpx

        with httpx.Client(headers=headers, timeout=60) as client:
            with connect_sse(client, "POST", self.zhipuai_api_base, json=payload) as event_source:
                for sse in event_source.iter_sse():
                    # The stream may end with a non-JSON "[DONE]" sentinel
                    # (OpenAI-style); stop instead of failing in json.loads.
                    if sse.data.strip() == "[DONE]":
                        break
                    event = json.loads(sse.data)

                    # Capture usage before skipping choice-less events so that
                    # usage delivered without choices is not lost.
                    usage = event.get("usage")
                    if usage:
                        self.usage_metadata = usage

                    choices = event.get("choices") or []
                    if not choices:
                        continue
                    choice = choices[0]

                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    finish_reason = choice.get("finish_reason")
                    generation_chunk = ChatGenerationChunk(
                        message=message_chunk, generation_info=usage or {}
                    )
                    # Notify callbacks before yielding so the final chunk is
                    # reported even if the consumer stops iterating after it.
                    if run_manager:
                        run_manager.on_llm_new_token(generation_chunk.text, chunk=generation_chunk)
                    yield generation_chunk
                    if finish_reason is not None:
                        break
|