diff --git a/apps/models_provider/impl/vllm_model_provider/model/llm.py b/apps/models_provider/impl/vllm_model_provider/model/llm.py index 6304c9150..8d9905c36 100644 --- a/apps/models_provider/impl/vllm_model_provider/model/llm.py +++ b/apps/models_provider/impl/vllm_model_provider/model/llm.py @@ -3,7 +3,7 @@ from typing import Dict, List from urllib.parse import urlparse, ParseResult -from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.messages import BaseMessage, get_buffer_string, AIMessageChunk from common.config.tokenizer_manage_config import TokenizerManage from models_provider.base_model_provider import MaxKBBaseModel @@ -48,3 +48,18 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI): tokenizer = TokenizerManage.get_tokenizer() return len(tokenizer.encode(text)) return self.get_last_generation_info().get('output_tokens', 0) + + def stream(self, input, config=None, *, stop=None, **kwargs): + has_content = False + for chunk in super().stream(input, config=config, stop=stop, **kwargs): + content = getattr(chunk, 'content', '') or '' + reasoning_content = (getattr(chunk, 'additional_kwargs', {}) or {}).get('reasoning_content', '') + if content or reasoning_content: + has_content = True + yield chunk + if not has_content: + result = self.invoke(input, config=config, stop=stop, **kwargs) + yield AIMessageChunk( + content=getattr(result, 'content', '') or '', + additional_kwargs=getattr(result, 'additional_kwargs', {}) or {} + )