from __future__ import annotations import json import uuid from typing import Any, Callable, Dict, Iterator, List, Optional from .base import AgentProvider, AssistantTurn, StreamEvent, ToolCall class OpenAICompatibleProvider(AgentProvider): def __init__( self, *, model: str, api_key: Optional[str] = None, base_url: Optional[str] = None, temperature: float = 0.2, timeout: float = 120.0, ) -> None: self.model = model self.api_key = api_key self.base_url = base_url self.temperature = temperature self.timeout = timeout def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn: from openai import OpenAI client = OpenAI(**self._client_kwargs()) request = self._build_request(messages, tools) try: response = client.chat.completions.create(**request) except Exception as exc: raise RuntimeError(_format_openai_error(exc, model=self.model, base_url=self.base_url)) from exc message = response.choices[0].message calls: list[ToolCall] = [] for item in message.tool_calls or []: calls.append( ToolCall( id=item.id or f"call_{uuid.uuid4().hex}", name=item.function.name, arguments=_parse_json(item.function.arguments or "{}"), ) ) return AssistantTurn(content=message.content or "", tool_calls=calls, raw=response) def stream_generate( self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], should_cancel: Optional[Callable[[], bool]] = None, ) -> Iterator[StreamEvent]: from openai import OpenAI client = OpenAI(**self._client_kwargs()) request = self._build_request(messages, tools) request["stream"] = True content_parts: list[str] = [] reasoning_parts: list[str] = [] tool_buffers: dict[int, dict[str, Any]] = {} try: stream = client.chat.completions.create(**request) for chunk in stream: if should_cancel and should_cancel(): stream.close() from core_agent.exceptions import AgentCancelled raise AgentCancelled("Agent run cancelled by user.") choice = chunk.choices[0] if chunk.choices else None if choice is None: continue delta = choice.delta reasoning_delta = _extract_reasoning_delta(delta) if reasoning_delta: reasoning_parts.append(reasoning_delta) yield StreamEvent(type="reasoning", delta=reasoning_delta, raw=chunk) content_delta = getattr(delta, "content", None) or "" if content_delta: content_parts.append(content_delta) yield StreamEvent(type="content", delta=content_delta, raw=chunk) for tool_call in getattr(delta, "tool_calls", None) or []: index = int(getattr(tool_call, "index", 0) or 0) entry = tool_buffers.setdefault( index, { "id": getattr(tool_call, "id", None) or f"call_{uuid.uuid4().hex}", "name": "", "arguments_parts": [], }, ) if getattr(tool_call, "id", None): entry["id"] = tool_call.id function = getattr(tool_call, "function", None) if function is not None: name = getattr(function, "name", None) if name: entry["name"] = name arguments = getattr(function, "arguments", None) if arguments: entry["arguments_parts"].append(arguments) except Exception as exc: if should_cancel and should_cancel(): from core_agent.exceptions import AgentCancelled raise AgentCancelled("Agent run cancelled by user.") from exc raise RuntimeError(_format_openai_error(exc, model=self.model, base_url=self.base_url)) from exc calls: list[ToolCall] = [] for index in sorted(tool_buffers): item = tool_buffers[index] calls.append( ToolCall( id=item["id"], name=item["name"], arguments=_parse_json("".join(item["arguments_parts"]) or "{}"), ) ) yield StreamEvent( type="turn", turn=AssistantTurn( content="".join(content_parts), reasoning="".join(reasoning_parts), tool_calls=calls, ), ) def _client_kwargs(self) -> dict[str, Any]: kwargs: dict[str, Any] = {"timeout": self.timeout} if self.api_key: kwargs["api_key"] = self.api_key if self.base_url: kwargs["base_url"] = self.base_url return kwargs def _build_request(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> dict[str, Any]: request: dict[str, Any] = { "model": self.model, "messages": messages, "temperature": self.temperature, } if tools: request["tools"] = tools request["tool_choice"] = "auto" request["parallel_tool_calls"] = False return request def _parse_json(raw: str) -> dict[str, Any]: try: return json.loads(raw) except json.JSONDecodeError: return {"raw_arguments": raw} def _extract_reasoning_delta(delta: Any) -> str: for attr in ("reasoning", "reasoning_content"): value = getattr(delta, attr, None) text = _coerce_text(value) if text: return text return "" def _coerce_text(value: Any) -> str: if value is None: return "" if isinstance(value, str): return value if isinstance(value, list): parts: list[str] = [] for item in value: text = _coerce_text(item) if text: parts.append(text) return "".join(parts) text = getattr(value, "text", None) if isinstance(text, str): return text content = getattr(value, "content", None) if isinstance(content, str): return content return "" def _format_openai_error(exc: Exception, *, model: str, base_url: Optional[str]) -> str: status_code = getattr(exc, "status_code", None) message = str(exc) response = getattr(exc, "response", None) if response is not None: try: payload = response.json() inner = payload.get("error", {}).get("message") if isinstance(payload, dict) else None if inner: message = inner except Exception: pass if status_code == 404 and "model" in message.lower(): return ( f"大模型调用失败:当前 MODEL_NAME={model!r} 在服务端不存在。\n" f"OPENAI_BASE_URL={base_url or ''}\n" "请把 .env 里的 MODEL_NAME 改成该服务实际暴露的模型名。" "可以运行 `python .\\main_cli.py --list-models` 查看可用模型。" ) return ( f"大模型调用失败:{message}\n" f"当前 MODEL_NAME={model!r}, OPENAI_BASE_URL={base_url or ''}。" )