UnisKB/apps/models_provider/impl/gemini_model_provider/model/llm.py

101 lines
3.8 KiB
Python
Raw Normal View History

2025-04-17 10:01:33 +00:00
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File llm.py
@Author Brian Yang
@Date 5/13/24 7:40 AM
"""
from typing import List, Dict, Optional, Sequence, Union, Any, Iterator, cast
from google.ai.generativelanguage_v1 import GenerateContentResponse
from google.ai.generativelanguage_v1beta.types import (
Tool as GoogleTool,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
2025-06-05 01:12:38 +00:00
from langchain_core.messages import BaseMessage, get_buffer_string
2025-04-17 10:01:33 +00:00
from langchain_core.outputs import ChatGenerationChunk
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
from langchain_google_genai.chat_models import _chat_with_retry, _response_to_result, \
_FunctionDeclarationType
from langchain_google_genai._common import (
SafetySettingDict,
)
2025-06-05 01:12:38 +00:00
from common.config.tokenizer_manage_config import TokenizerManage
2025-04-17 10:01:33 +00:00
from models_provider.base_model_provider import MaxKBBaseModel
class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
    """Gemini chat model that records token usage while streaming.

    Usage metadata found on streamed chunks is merged into a dict stored in
    the instance ``__dict__`` under the key ``_last_generation_info``. The
    token-counting methods read that dict first and only fall back to the
    shared tokenizer when no usage has been recorded yet.
    """

    @staticmethod
    def is_cache_model():
        """Return ``False``.

        NOTE(review): presumably signals that instances of this model must
        not be cached by the provider framework -- confirm against
        ``MaxKBBaseModel``.
        """
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a ``GeminiChatModel`` from a model name and credential dict.

        Args:
            model_type: Not used by this implementation.
            model_name: Passed to the constructor as ``model``.
            model_credential: Mapping whose ``'api_key'`` entry is passed as
                ``google_api_key``.
            **model_kwargs: Filtered through
                ``MaxKBBaseModel.filter_optional_params`` and forwarded to
                the constructor.

        Returns:
            The newly constructed ``GeminiChatModel``.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return GeminiChatModel(
            model=model_name,
            google_api_key=model_credential.get('api_key'),
            **optional_params
        )

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage dict recorded by ``_stream``, or ``None`` if absent."""
        # Read straight from ``__dict__`` because that is where ``_stream``
        # stores the value (it is not a declared attribute of the class).
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the input-token count for the last streamed generation.

        When usage info has been recorded, its ``'input_tokens'`` entry
        (default ``0``) is returned and ``messages`` is not inspected.
        Otherwise each message is rendered with ``get_buffer_string`` and
        counted with the tokenizer from ``TokenizerManage``.
        """
        usage = self.get_last_generation_info()
        if usage is not None:
            return usage.get('input_tokens', 0)
        # Nothing recorded yet: estimate locally.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        """Return the output-token count for the last streamed generation.

        When usage info has been recorded, its ``'output_tokens'`` entry
        (default ``0``) is returned and ``text`` is not inspected. Otherwise
        ``text`` is counted with the tokenizer from ``TokenizerManage``.
        """
        usage = self.get_last_generation_info()
        if usage is not None:
            return usage.get('output_tokens', 0)
        # Nothing recorded yet: estimate locally.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            *,
            tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
            functions: Optional[Sequence[_FunctionDeclarationType]] = None,
            safety_settings: Optional[SafetySettingDict] = None,
            tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
            generation_config: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream generation chunks and record their token usage.

        Builds the request with ``_prepare_request``, calls
        ``self.client.stream_generate_content`` through ``_chat_with_retry``
        and yields one ``ChatGenerationChunk`` per response chunk. Usage
        metadata found on a chunk's message is merged into
        ``_last_generation_info``; ``run_manager`` (if given) is notified of
        each chunk's text.

        Yields:
            The first generation of every streamed response chunk.
        """
        request = self._prepare_request(
            messages,
            stop=stop,
            tools=tools,
            functions=functions,
            safety_settings=safety_settings,
            tool_config=tool_config,
            generation_config=generation_config,
        )
        response: GenerateContentResponse = _chat_with_retry(
            request=request,
            generation_method=self.client.stream_generate_content,
            **kwargs,
            metadata=self.default_metadata,
        )
        for chunk in response:
            _chat_result = _response_to_result(chunk, stream=True)
            gen = cast(ChatGenerationChunk, _chat_result.generations[0])
            if gen.message:
                token_usage = gen.message.usage_metadata
                # ``usage_metadata`` is optional and may be None on a chunk;
                # ``dict.update(None)`` would raise TypeError and abort the
                # stream, so only merge when there is something to merge.
                if token_usage:
                    self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
            if run_manager:
                run_manager.on_llm_new_token(gen.text)
            yield gen