390 lines
18 KiB
Python
390 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
import json
|
|
from pathlib import Path
|
|
import re
|
|
import traceback
|
|
from typing import Any, Dict, Iterator, List, Optional
|
|
|
|
from core_agent.compression import RollingContextCompressor
|
|
from prompts import DEFAULT_SYSTEM_PROMPT, build_system_prompt
|
|
from providers.base import AgentProvider, AssistantTurn
|
|
from skills import SkillStore
|
|
from tools.default_tools import build_default_registry
|
|
from tools.registry import ToolContext, ToolRegistry
|
|
from tools.tool_trace import ToolTraceStore, ToolTraceTurn
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ChatEvent:
|
|
type: str
|
|
delta: str = ""
|
|
tool_name: str = ""
|
|
tool_args: Dict[str, Any] = field(default_factory=dict)
|
|
tool_result: str = ""
|
|
final_response: str = ""
|
|
turn_content: str = ""
|
|
turn_reasoning: str = ""
|
|
iteration: int = 0
|
|
max_iterations: int = 0
|
|
final_reason: str = ""
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AgentRunResult:
|
|
final_response: str
|
|
messages: List[Dict[str, Any]]
|
|
session: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
from core_agent.exceptions import AgentCancelled
|
|
|
|
|
|
class ConversationSession:
    """Multi-round chat session that drives a provider/tool loop and keeps recent history.

    ``stream_ask`` sends the system prompt, the (possibly compacted) recent history
    and the new user message to ``provider``. Tool calls in the provider's answer are
    executed through ``tool_registry`` and their results are fed back, round after
    round, until the provider answers without tool calls or a stop condition is hit.
    Finished exchanges are persisted to
    ``<workspace>/agent_memory/sessions/<session_id>/recent_history.json``.
    """

    # Loop guard: end the run when an identical (tool name, arguments) call has been
    # requested this many times during one ``stream_ask`` run.
    REPEAT_CALL_LIMIT = 3
    # Loop guard: end the run when the same tool name (any arguments) has been counted
    # this many times during one ``stream_ask`` run.
    TOOL_NAME_CALL_LIMIT = 6
    # Text used for AgentCancelled and for cancellation detected between steps.
    _CANCEL_MESSAGE = "Agent run cancelled by user."

    def __init__(
        self,
        *,
        provider: AgentProvider,
        workspace: str | Path,
        session_id: str = "default",
        skill_dirs: Optional[List[str | Path]] = None,
        tool_registry: Optional[ToolRegistry] = None,
        system_prompt: Optional[str] = None,
        max_iterations: int = 12,
        max_history_turns: int = 20,
        max_persisted_turns: int = 20,
        active_skills: Optional[List[str]] = None,
    ) -> None:
        """Create a session and load any persisted history for ``session_id``.

        Args:
            provider: Backend whose ``stream_generate`` produces each assistant turn.
            workspace: Root directory (resolved to an absolute path); the default
                skill, tool-data and session-storage locations live underneath it.
            session_id: Session name, sanitized into a safe directory name.
            skill_dirs: Skill directories; defaults to ``<workspace>/skills``.
            tool_registry: Tools offered to the model; defaults to
                ``build_default_registry(<workspace>/data)``.
            system_prompt: Base system prompt; defaults to ``DEFAULT_SYSTEM_PROMPT``.
            max_iterations: Maximum provider rounds per ``stream_ask`` call.
            max_history_turns: Completed turns (user + assistant pairs) of history
                sent to the model; also the turn limit of the tool trace store.
            max_persisted_turns: Completed turns kept in memory and on disk.
            active_skills: Initial value of ``session_state["active_skills"]``.
        """
        self.provider = provider
        self.workspace = Path(workspace).resolve()
        self.session_id = _sanitize_session_id(session_id)
        self.skill_store = SkillStore(skill_dirs or [self.workspace / "skills"])
        self.tool_registry = tool_registry or build_default_registry(self.workspace / "data")
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self.max_iterations = max_iterations
        self.max_history_turns = max_history_turns
        self.max_persisted_turns = max_persisted_turns
        # Limits used by _tool_result_for_model / _compact_tool_value to keep large
        # tool outputs from flooding the model context.
        self.max_tool_result_chars_for_model = 12000
        self.max_tool_string_field_chars = 1500
        self.max_tool_list_items = 8
        self.max_tool_dict_items = 20
        self.session_dir = self.workspace / "agent_memory" / "sessions" / self.session_id
        self.history_path = self.session_dir / "recent_history.json"
        self.tool_trace_store = ToolTraceStore(self.session_dir / "tool_trace.json", max_turns=max_history_turns)
        self.context_compressor = RollingContextCompressor()
        # Shared with tools through ToolContext (see stream_ask).
        self.session_state: Dict[str, Any] = {
            "workspace": str(self.workspace),
            "active_skills": list(active_skills or []),
            "session_id": self.session_id,
            "_cancel_requested": False,
        }
        self.history: List[Dict[str, Any]] = self._load_recent_history()
        self.last_messages: List[Dict[str, Any]] = []

    def ask(self, user_message: str) -> AgentRunResult:
        """Run :meth:`stream_ask` to completion and return only its outcome.

        Raises:
            RuntimeError: If the event stream ends without a ``"final"`` event.
        """
        final: ChatEvent | None = None
        for event in self.stream_ask(user_message):
            final = event
        if final is None or final.type != "final":
            raise RuntimeError("Conversation ended without a final response")
        return AgentRunResult(
            final_response=final.final_response,
            messages=list(self.last_messages),
            session=dict(self.session_state),
        )

    def request_cancel(self) -> None:
        """Ask the running ``stream_ask`` loop to stop at its next cancellation check."""
        self.session_state["_cancel_requested"] = True

    def clear_cancel(self) -> None:
        """Reset the cancellation flag (done automatically when ``stream_ask`` starts)."""
        self.session_state["_cancel_requested"] = False

    def cancellation_requested(self) -> bool:
        """Return True if :meth:`request_cancel` was called since the last reset."""
        return bool(self.session_state.get("_cancel_requested"))

    def _raise_if_cancelled(self) -> None:
        """Raise :class:`AgentCancelled` when cancellation has been requested."""
        if self.cancellation_requested():
            raise AgentCancelled(self._CANCEL_MESSAGE)

    def stream_ask(self, user_message: str) -> Iterator[ChatEvent]:
        """Answer ``user_message``, yielding progress events as they happen.

        Each provider round yields ``round_start``, streamed ``reasoning_delta`` /
        ``content_delta`` events, then ``assistant_turn``. Each executed tool call
        yields ``tool_call`` followed by ``tool_result``. A run that is not
        interrupted by an unexpected exception ends with one ``final`` event whose
        ``final_reason`` is:

        * ``"completed"`` - the provider answered without tool calls;
        * ``"cancelled"`` - :meth:`request_cancel` was called (checked before each
          round, during streaming, and before/while executing each tool call);
        * ``"provider_error"`` - an exception occurred while obtaining a provider turn;
        * ``"repeat_loop_detected"`` - ``REPEAT_CALL_LIMIT`` or
          ``TOOL_NAME_CALL_LIMIT`` was reached;
        * ``"max_iterations"`` - ``max_iterations`` rounds passed without an answer.

        Only completed, loop-detected and max-iterations runs are appended to the
        persisted history. Exceptions raised by a tool are converted into an error
        result that is sent back to the model.
        """
        self.clear_cancel()
        messages = self.build_messages(user_message)
        self.last_messages = list(messages)
        trace_turn = self.tool_trace_store.begin_turn(user_message)
        # Counters span all rounds of this run so loops across rounds are caught.
        repeat_counter: Dict[str, int] = {}
        tool_name_counter: Dict[str, int] = {}

        for iteration in range(1, self.max_iterations + 1):
            # Checked explicitly (not via _raise_if_cancelled) so a cancel between
            # rounds also ends with a "cancelled" final event instead of an exception.
            if self.cancellation_requested():
                yield self._abort_run(messages, trace_turn, self._CANCEL_MESSAGE, iteration, "cancelled")
                return
            yield ChatEvent(type="round_start", iteration=iteration, max_iterations=self.max_iterations)

            try:
                turn: AssistantTurn | None = None
                for stream_event in self.provider.stream_generate(
                    messages,
                    self.tool_registry.definitions(),
                    should_cancel=self.cancellation_requested,
                ):
                    self._raise_if_cancelled()
                    if stream_event.type == "reasoning" and stream_event.delta:
                        yield ChatEvent(
                            type="reasoning_delta",
                            delta=stream_event.delta,
                            iteration=iteration,
                            max_iterations=self.max_iterations,
                        )
                    elif stream_event.type == "content" and stream_event.delta:
                        yield ChatEvent(
                            type="content_delta",
                            delta=stream_event.delta,
                            iteration=iteration,
                            max_iterations=self.max_iterations,
                        )
                    elif stream_event.type == "turn":
                        turn = stream_event.turn
                if turn is None:
                    raise RuntimeError("Provider stream ended without a final turn")
                self._raise_if_cancelled()
            except AgentCancelled as exc:
                yield self._abort_run(messages, trace_turn, str(exc), iteration, "cancelled")
                return
            except Exception as exc:
                # Fall back to the exception type so the final message is never empty.
                final = str(exc) or type(exc).__name__
                yield self._abort_run(messages, trace_turn, final, iteration, "provider_error")
                return

            yield ChatEvent(
                type="assistant_turn",
                turn_content=turn.content or "",
                turn_reasoning=turn.reasoning or "",
                iteration=iteration,
                max_iterations=self.max_iterations,
            )
            messages.append(_assistant_message(turn))

            if not turn.tool_calls:
                yield self._complete_run(messages, user_message, turn.content or "", trace_turn, iteration, "completed")
                return

            ctx = ToolContext(workspace=self.workspace, session=self.session_state, tool_trace_store=self.tool_trace_store)
            loop_message = ""
            for call in turn.tool_calls:
                if self.cancellation_requested():
                    yield self._abort_run(messages, trace_turn, self._CANCEL_MESSAGE, iteration, "cancelled")
                    return
                loop_message = self._check_tool_loop(call, repeat_counter, tool_name_counter)
                if loop_message:
                    break
                yield ChatEvent(type="tool_call", tool_name=call.name, tool_args=call.arguments, iteration=iteration, max_iterations=self.max_iterations)
                try:
                    result = self.tool_registry.execute(call.name, call.arguments, ctx)
                except AgentCancelled as exc:
                    yield self._abort_run(messages, trace_turn, str(exc), iteration, "cancelled")
                    return
                except Exception as exc:
                    # Tool failures are reported back to the model, not raised.
                    result = _tool_error_result(call.name, call.arguments, exc)
                self.tool_trace_store.record_tool_call(
                    trace_turn,
                    tool_name=call.name,
                    arguments=call.arguments,
                    thought=turn.reasoning or "",
                    result=result,
                    tool_call_id=call.id,
                )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": call.name,
                        "content": self._tool_result_for_model(call.name, result),
                    }
                )
                yield ChatEvent(type="tool_result", tool_name=call.name, tool_args=call.arguments, tool_result=result, iteration=iteration, max_iterations=self.max_iterations)

            if loop_message:
                yield self._complete_run(messages, user_message, loop_message, trace_turn, iteration, "repeat_loop_detected")
                return

            # Rebuild the system prompt for the next round (it depends on
            # session_state["active_skills"], which tools receive via ctx).
            messages[0]["content"] = self._system_prompt()

        final = "Agent reached max_iterations before producing a final response."
        yield self._complete_run(messages, user_message, final, trace_turn, self.max_iterations, "max_iterations")

    def _check_tool_loop(self, call: Any, repeat_counter: Dict[str, int], tool_name_counter: Dict[str, int]) -> str:
        """Count ``call`` against the loop guards; return a stop message or ``""``.

        The per-name counter is only advanced when the identical-call guard did
        not fire.
        """
        call_key = json.dumps({"name": call.name, "args": call.arguments}, sort_keys=True, ensure_ascii=False)
        repeat_counter[call_key] = repeat_counter.get(call_key, 0) + 1
        if repeat_counter[call_key] >= self.REPEAT_CALL_LIMIT:
            return "检测到重复工具调用(相同参数连续调用 {0} 次),已自动终止循环。请检查工具参数或换用更大参数的模型。".format(self.REPEAT_CALL_LIMIT)
        tool_name_counter[call.name] = tool_name_counter.get(call.name, 0) + 1
        if tool_name_counter[call.name] >= self.TOOL_NAME_CALL_LIMIT:
            return "检测到同一工具「{0}」被连续调用 {1} 次(参数略有不同但未产生有效进展),已自动终止循环。建议换用更大参数的模型。".format(call.name, self.TOOL_NAME_CALL_LIMIT)
        return ""

    def _abort_run(
        self,
        messages: List[Dict[str, Any]],
        trace_turn: ToolTraceTurn,
        final: str,
        iteration: int,
        reason: str,
    ) -> ChatEvent:
        """Close the trace of a run that is NOT saved to history; return its final event."""
        self.last_messages = list(messages)
        self.tool_trace_store.finish_turn(trace_turn, final)
        return ChatEvent(type="final", final_response=final, iteration=iteration, max_iterations=self.max_iterations, final_reason=reason)

    def _complete_run(
        self,
        messages: List[Dict[str, Any]],
        user_message: str,
        final: str,
        trace_turn: ToolTraceTurn,
        iteration: int,
        reason: str,
    ) -> ChatEvent:
        """Save the exchange to history, close the trace and return the final event."""
        self._finalize_turn(user_message, final, trace_turn)
        self.last_messages = list(messages)
        return ChatEvent(type="final", final_response=final, iteration=iteration, max_iterations=self.max_iterations, final_reason=reason)

    def build_messages(self, user_message: str) -> List[Dict[str, Any]]:
        """Build the initial message list for a run and record compression stats.

        The list is: system prompt, optional compression summary, the compressor's
        tail of recent history, then the new user message.
        """
        # History stores exactly one user and one assistant message per completed
        # turn, so max_history_turns turns are max_history_turns * 2 messages.
        recent_history = self._tail(self.history, self.max_history_turns * 2)
        compression = self.context_compressor.compact(recent_history)

        messages: List[Dict[str, Any]] = [{"role": "system", "content": self._system_prompt()}]
        if compression.summary_message:
            messages.append(compression.summary_message)
        messages.extend(compression.tail_messages)
        messages.append({"role": "user", "content": user_message})

        self.session_state["compression"] = {
            "did_compact": compression.did_compact,
            "estimated_tokens": compression.estimated_tokens,
            "has_summary": bool(compression.summary_message),
        }
        return messages

    def persisted_dialog_messages(self) -> List[Dict[str, str]]:
        """Return a shallow copy of the in-memory user/assistant history."""
        return list(self.history)

    def _append_history(self, user_message: str, final_content: str) -> None:
        """Append one completed turn to history, trim it and write it to disk."""
        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": final_content})
        self._trim_recent_history()
        self._save_recent_history()

    def _finalize_turn(self, user_message: str, final_content: str, trace_turn: ToolTraceTurn) -> None:
        """Persist a finished turn to history and close its tool-trace entry."""
        self._append_history(user_message, final_content)
        self.tool_trace_store.finish_turn(trace_turn, final_content)

    def _load_recent_history(self) -> List[Dict[str, str]]:
        """Load persisted history, skipping malformed entries.

        Returns an empty list when the file is missing, unreadable, not valid
        UTF-8, or not valid JSON. Only non-empty ``user`` / ``assistant`` entries
        are kept (content is stripped), limited to ``max_persisted_turns`` turns.
        """
        if not self.history_path.is_file():
            return []
        try:
            raw = json.loads(self.history_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return []
        items = raw.get("messages", []) if isinstance(raw, dict) else []
        history: List[Dict[str, str]] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            role = str(item.get("role", "")).strip()
            content = str(item.get("content", "")).strip()
            if role not in {"user", "assistant"} or not content:
                continue
            history.append({"role": role, "content": content})
        return self._tail(history, self.max_persisted_turns * 2)

    def _save_recent_history(self) -> None:
        """Write the last ``max_persisted_turns`` turns to ``history_path``.

        The JSON is written to a sibling temp file first and then moved over the
        target, so a crash mid-write cannot leave a truncated history file.
        """
        self.history_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "session_id": self.session_id,
            "messages": self._tail(self.history, self.max_persisted_turns * 2),
            "max_persisted_turns": self.max_persisted_turns,
        }
        tmp_path = self.history_path.with_name(self.history_path.name + ".tmp")
        tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        tmp_path.replace(self.history_path)

    def _trim_recent_history(self) -> None:
        """Drop in-memory history beyond the last ``max_persisted_turns`` turns."""
        self.history = self._tail(self.history, self.max_persisted_turns * 2)

    @staticmethod
    def _tail(items: List[Any], count: int) -> List[Any]:
        """Return a new list with the last ``count`` items (empty when ``count <= 0``).

        A plain ``items[-count:]`` would return the WHOLE list for ``count == 0``.
        """
        return items[-count:] if count > 0 else []

    def _system_prompt(self) -> str:
        """Render the system prompt from skills, active skills, tools and workspace."""
        return build_system_prompt(
            skill_store=self.skill_store,
            active_skills=self.session_state.get("active_skills", []),
            tool_registry=self.tool_registry,
            workspace=str(self.workspace),
            base_prompt=self.system_prompt,
        )

    def _tool_result_for_model(self, tool_name: str, result: str) -> str:
        """Return ``result`` as-is if short enough, else a compact JSON summary of it.

        Results longer than ``max_tool_result_chars_for_model`` are parsed as JSON
        when possible (otherwise treated as one string), shrunk with
        :meth:`_compact_tool_value` and wrapped with truncation metadata.
        """
        text = str(result or "")
        if len(text) <= self.max_tool_result_chars_for_model:
            return text
        try:
            payload: Any = json.loads(text)
        except json.JSONDecodeError:
            payload = text  # not JSON: summarise the raw text as a single string
        summary = {
            "_tool_name": tool_name,
            "_truncated_for_model": True,
            "_original_chars": len(text),
            "_note": "Full tool result was stored in tool_trace.json. Use tool_trace_query if more detail is needed.",
            "result_summary": self._compact_tool_value(payload, depth=0),
        }
        return json.dumps(summary, ensure_ascii=False, indent=2)

    def _compact_tool_value(self, value: Any, *, depth: int) -> Any:
        """Recursively shrink a JSON-like value for the model.

        Dicts keep their first ``max_tool_dict_items`` entries plus a
        ``_remaining_fields`` count; lists keep their first ``max_tool_list_items``
        items plus a ``"... N more items"`` marker; containers at depth >= 4 become
        a placeholder string; scalars go through :meth:`_compact_scalar`.
        """
        if depth >= 4:
            if isinstance(value, (dict, list)):
                return f"[{type(value).__name__} truncated]"
            return self._compact_scalar(value)
        if isinstance(value, dict):
            items = list(value.items())
            compact: Dict[str, Any] = {}
            for index, (key, item) in enumerate(items):
                if index >= self.max_tool_dict_items:
                    compact["_remaining_fields"] = len(items) - self.max_tool_dict_items
                    break
                compact[str(key)] = self._compact_tool_value(item, depth=depth + 1)
            return compact
        if isinstance(value, list):
            compact_list = [
                self._compact_tool_value(item, depth=depth + 1)
                for item in value[: self.max_tool_list_items]
            ]
            if len(value) > self.max_tool_list_items:
                compact_list.append(f"... {len(value) - self.max_tool_list_items} more items")
            return compact_list
        return self._compact_scalar(value)

    def _compact_scalar(self, value: Any) -> Any:
        """Collapse whitespace in strings and truncate them; return other values unchanged."""
        if isinstance(value, str):
            compact = " ".join(value.split())
            if len(compact) > self.max_tool_string_field_chars:
                return compact[: self.max_tool_string_field_chars] + "...[truncated]"
            return compact
        return value
|
|
|
|
|
|
def _assistant_message(turn: AssistantTurn) -> Dict[str, Any]:
    """Convert a provider turn into an ``assistant`` chat message.

    The message always has ``role`` and ``content`` (``""`` when the turn has no
    content). When the turn requested tools, a ``tool_calls`` list is added where
    each entry is ``{"id", "type": "function", "function": {"name", "arguments"}}``
    and ``arguments`` is a JSON-encoded string.
    """

    def encode(call: Any) -> Dict[str, Any]:
        # Arguments travel as a JSON string, not as a nested object.
        arguments = json.dumps(call.arguments, ensure_ascii=False)
        return {"id": call.id, "type": "function", "function": {"name": call.name, "arguments": arguments}}

    content = turn.content or ""
    message: Dict[str, Any] = {"role": "assistant", "content": content}
    requested = turn.tool_calls
    if requested:
        message["tool_calls"] = [encode(call) for call in requested]
    return message
|
|
|
|
|
|
def _tool_error_result(tool_name: str, tool_args: Dict[str, Any], exc: Exception) -> str:
    """Serialize a failed tool call into the JSON string handed back to the model.

    Args:
        tool_name: Name of the tool that raised.
        tool_args: Arguments the tool was called with.
        exc: The exception raised by the tool.

    Returns:
        Pretty-printed JSON with ``success: false``, the error text, the call
        details, ``retryable: false`` and a traceback of ``exc``.
    """
    # Format the traceback from ``exc`` itself rather than from the "currently
    # handled" exception (traceback.format_exc), so the output is correct even when
    # this helper runs outside an ``except`` block. The negative limit keeps the
    # innermost frames, i.e. the ones inside the tool where the failure happened.
    trace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__, limit=-8))
    payload = {
        "success": False,
        "error": f"{type(exc).__name__}: {exc}",
        "tool_name": tool_name,
        "tool_args": tool_args,
        "retryable": False,
        "traceback": trace,
    }
    # default=repr: the error report itself must not fail on non-JSON argument values.
    return json.dumps(payload, ensure_ascii=False, indent=2, default=repr)
|
|
|
|
|
|
def _sanitize_session_id(value: str) -> str:
    """Turn an arbitrary session id into a safe single directory name.

    Every run of characters outside ``[A-Za-z0-9._-]`` becomes one ``_``, then
    leading/trailing dots and underscores are stripped (so the result cannot be
    ``.``, ``..`` or start with a dot). Falls back to ``"default"`` when the
    input is falsy or nothing usable remains.
    """
    # NOTE(review): ids made only of non-ASCII characters (e.g. Chinese names) all
    # reduce to "default" and therefore share one session directory.
    candidate = str(value) if value else "default"
    candidate = re.sub(r"[^A-Za-z0-9._-]+", "_", candidate)
    candidate = candidate.strip("._")
    if not candidate:
        return "default"
    return candidate
|