meeting_memory/providers/openai_compatible.py

209 lines
7.6 KiB
Python

from __future__ import annotations
import json
import uuid
from typing import Any, Callable, Dict, Iterator, List, Optional
from .base import AgentProvider, AssistantTurn, StreamEvent, ToolCall
class OpenAICompatibleProvider(AgentProvider):
    """Agent provider backed by an OpenAI-compatible Chat Completions endpoint.

    The provider only stores configuration; every call builds (and now closes)
    its own ``openai.OpenAI`` client.
    """

    def __init__(
        self,
        *,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.2,
        timeout: float = 120.0,
    ) -> None:
        """Store connection and sampling settings.

        Args:
            model: Model name sent as ``model`` in every request.
            api_key: API key; omitted from the client kwargs when falsy, leaving
                the SDK to choose its default.
            base_url: Endpoint base URL; omitted from the client kwargs when falsy.
            temperature: Sampling temperature sent with every request.
            timeout: Client timeout passed to the SDK.
        """
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.temperature = temperature
        self.timeout = timeout

    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Run one non-streaming chat completion and return the assistant turn.

        Args:
            messages: Chat messages in OpenAI wire format.
            tools: Tool schemas; when empty no tool options are sent.

        Returns:
            An ``AssistantTurn`` with the text content, parsed tool calls and
            the raw SDK response.

        Raises:
            RuntimeError: if the request fails or the response has no choices;
                the text comes from ``_format_openai_error``.
        """
        from openai import OpenAI

        request = self._build_request(messages, tools)
        # The client owns an HTTP connection pool: close it instead of leaking
        # one pool per call.
        with OpenAI(**self._client_kwargs()) as client:
            try:
                response = client.chat.completions.create(**request)
            except Exception as exc:
                raise RuntimeError(_format_openai_error(exc, model=self.model, base_url=self.base_url)) from exc
        if not response.choices:
            # Without this guard an empty list would escape as a bare IndexError.
            raise RuntimeError(
                _format_openai_error(
                    RuntimeError("response contained no choices"),
                    model=self.model,
                    base_url=self.base_url,
                )
            )
        message = response.choices[0].message
        calls: list[ToolCall] = [
            ToolCall(
                # Some servers omit the id; synthesize one so results can be matched.
                id=item.id or f"call_{uuid.uuid4().hex}",
                name=item.function.name,
                arguments=_parse_json(item.function.arguments or "{}"),
            )
            for item in message.tool_calls or []
        ]
        return AssistantTurn(content=message.content or "", tool_calls=calls, raw=response)

    def stream_generate(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        should_cancel: Optional[Callable[[], bool]] = None,
    ) -> Iterator[StreamEvent]:
        """Stream a chat completion as incremental events.

        Yields a ``StreamEvent(type="reasoning")`` or ``StreamEvent(type="content")``
        for each non-empty delta, then one final ``StreamEvent(type="turn")``
        carrying the assembled ``AssistantTurn`` (content, reasoning, tool calls).

        Args:
            messages: Chat messages in OpenAI wire format.
            tools: Tool schemas; when empty no tool options are sent.
            should_cancel: Optional callback polled before each chunk; when it
                returns True the stream is closed and ``AgentCancelled`` is raised.

        Raises:
            AgentCancelled: when cancellation is requested, including when an
                error occurs while ``should_cancel()`` reports True.
            RuntimeError: for any other failure, formatted by ``_format_openai_error``.
        """
        from openai import OpenAI

        request = self._build_request(messages, tools)
        request["stream"] = True
        content_parts: list[str] = []
        reasoning_parts: list[str] = []
        tool_buffers: dict[int, dict[str, Any]] = {}
        cancelled = False
        with OpenAI(**self._client_kwargs()) as client:
            try:
                stream = client.chat.completions.create(**request)
                # ``finally`` releases the HTTP response on every exit path:
                # normal end, cancellation, errors, and a consumer that stops
                # iterating early (GeneratorExit raised at a ``yield``).
                try:
                    for chunk in stream:
                        if should_cancel and should_cancel():
                            # Raised below, outside the error wrapper.
                            cancelled = True
                            break
                        choice = chunk.choices[0] if chunk.choices else None
                        if choice is None:
                            continue
                        delta = choice.delta
                        reasoning_delta = _extract_reasoning_delta(delta)
                        if reasoning_delta:
                            reasoning_parts.append(reasoning_delta)
                            yield StreamEvent(type="reasoning", delta=reasoning_delta, raw=chunk)
                        content_delta = getattr(delta, "content", None) or ""
                        if content_delta:
                            content_parts.append(content_delta)
                            yield StreamEvent(type="content", delta=content_delta, raw=chunk)
                        for tool_call in getattr(delta, "tool_calls", None) or []:
                            self._merge_tool_call_delta(tool_buffers, tool_call)
                finally:
                    stream.close()
            except Exception as exc:
                if should_cancel and should_cancel():
                    from core_agent.exceptions import AgentCancelled
                    raise AgentCancelled("Agent run cancelled by user.") from exc
                raise RuntimeError(_format_openai_error(exc, model=self.model, base_url=self.base_url)) from exc
        if cancelled:
            from core_agent.exceptions import AgentCancelled
            raise AgentCancelled("Agent run cancelled by user.")
        yield StreamEvent(
            type="turn",
            turn=AssistantTurn(
                content="".join(content_parts),
                reasoning="".join(reasoning_parts),
                tool_calls=self._collect_tool_calls(tool_buffers),
            ),
        )

    @staticmethod
    def _merge_tool_call_delta(buffers: dict[int, dict[str, Any]], tool_call: Any) -> None:
        """Fold one streamed tool-call fragment into ``buffers``, keyed by its index."""
        index = int(getattr(tool_call, "index", 0) or 0)
        call_id = getattr(tool_call, "id", None)
        entry = buffers.setdefault(
            index,
            # Placeholder id in case the server never sends one for this index.
            {"id": call_id or f"call_{uuid.uuid4().hex}", "name": "", "arguments_parts": []},
        )
        if call_id:
            entry["id"] = call_id
        function = getattr(tool_call, "function", None)
        if function is None:
            return
        name = getattr(function, "name", None)
        if name:
            entry["name"] = name
        arguments = getattr(function, "arguments", None)
        if arguments:
            # Arguments arrive as JSON text split across chunks; join at the end.
            entry["arguments_parts"].append(arguments)

    @staticmethod
    def _collect_tool_calls(buffers: dict[int, dict[str, Any]]) -> list[ToolCall]:
        """Convert accumulated tool-call buffers into ``ToolCall`` objects ordered by index."""
        return [
            ToolCall(
                id=item["id"],
                name=item["name"],
                arguments=_parse_json("".join(item["arguments_parts"]) or "{}"),
            )
            for _, item in sorted(buffers.items())
        ]

    def _client_kwargs(self) -> dict[str, Any]:
        """Build keyword arguments for ``OpenAI(...)``, omitting unset credentials/URL."""
        kwargs: dict[str, Any] = {"timeout": self.timeout}
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.base_url:
            kwargs["base_url"] = self.base_url
        return kwargs

    def _build_request(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> dict[str, Any]:
        """Build the ``chat.completions.create`` payload.

        Tool options are added only when ``tools`` is non-empty. In that case
        ``tool_choice`` is ``"auto"`` and ``parallel_tool_calls`` is forced to False.
        """
        request: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
        }
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"
            request["parallel_tool_calls"] = False
        return request
def _parse_json(raw: str) -> dict[str, Any]:
try:
return json.loads(raw)
except json.JSONDecodeError:
return {"raw_arguments": raw}
def _extract_reasoning_delta(delta: Any) -> str:
    """Return the reasoning text carried by a streamed delta, or ``""``.

    ``delta.reasoning`` is consulted first, then ``delta.reasoning_content``;
    the first attribute that coerces to non-empty text wins. Attributes are
    read lazily, so the second one is only touched when the first is empty.
    """
    attribute_values = (getattr(delta, attr_name, None) for attr_name in ("reasoning", "reasoning_content"))
    return next((candidate for candidate in map(_coerce_text, attribute_values) if candidate), "")
def _coerce_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, list):
parts: list[str] = []
for item in value:
text = _coerce_text(item)
if text:
parts.append(text)
return "".join(parts)
text = getattr(value, "text", None)
if isinstance(text, str):
return text
content = getattr(value, "content", None)
if isinstance(content, str):
return content
return ""
def _format_openai_error(exc: Exception, *, model: str, base_url: Optional[str]) -> str:
status_code = getattr(exc, "status_code", None)
message = str(exc)
response = getattr(exc, "response", None)
if response is not None:
try:
payload = response.json()
inner = payload.get("error", {}).get("message") if isinstance(payload, dict) else None
if inner:
message = inner
except Exception:
pass
if status_code == 404 and "model" in message.lower():
return (
f"大模型调用失败:当前 MODEL_NAME={model!r} 在服务端不存在。\n"
f"OPENAI_BASE_URL={base_url or ''}\n"
"请把 .env 里的 MODEL_NAME 改成该服务实际暴露的模型名。"
"可以运行 `python .\\main_cli.py --list-models` 查看可用模型。"
)
return (
f"大模型调用失败:{message}\n"
f"当前 MODEL_NAME={model!r}, OPENAI_BASE_URL={base_url or ''}"
)