meeting_memory/core_agent/session.py

321 lines
14 KiB
Python
Raw Normal View History

2026-06-24 07:05:19 +00:00
from __future__ import annotations
from dataclasses import dataclass, field
import json
from pathlib import Path
import re
import traceback
from typing import Any, Dict, Iterator, List, Optional
from core_agent.compression import RollingContextCompressor
from prompts import DEFAULT_SYSTEM_PROMPT, build_system_prompt
from providers.base import AgentProvider, AssistantTurn
from skills import SkillStore
from tools.default_tools import build_default_registry
from tools.registry import ToolContext, ToolRegistry
from tools.tool_trace import ToolTraceStore, ToolTraceTurn
@dataclass(slots=True)
class ChatEvent:
type: str
delta: str = ""
tool_name: str = ""
tool_args: Dict[str, Any] = field(default_factory=dict)
tool_result: str = ""
final_response: str = ""
turn_content: str = ""
turn_reasoning: str = ""
iteration: int = 0
max_iterations: int = 0
final_reason: str = ""
@dataclass(slots=True)
class AgentRunResult:
final_response: str
messages: List[Dict[str, Any]]
session: Dict[str, Any] = field(default_factory=dict)
from core_agent.exceptions import AgentCancelled
class ConversationSession:
    """Multi-turn agent conversation bound to one workspace and session id.

    The session drives a provider/tool loop (``stream_ask``), keeps a bounded
    list of user/assistant messages in memory and mirrors it to
    ``<workspace>/agent_memory/sessions/<session_id>/recent_history.json``.
    """

    # Abort a run once the same (tool name, arguments) pair has been requested
    # this many times within that run.
    _REPEAT_LIMIT = 3
    # Abort a run once one tool has been requested this many times within that
    # run, regardless of its arguments.
    _TOOL_NAME_LIMIT = 6

    def __init__(
        self,
        *,
        provider: AgentProvider,
        workspace: str | Path,
        session_id: str = "default",
        skill_dirs: Optional[List[str | Path]] = None,
        tool_registry: Optional[ToolRegistry] = None,
        system_prompt: Optional[str] = None,
        max_iterations: int = 12,
        max_history_turns: int = 20,
        max_persisted_turns: int = 20,
        active_skills: Optional[List[str]] = None,
    ) -> None:
        """Create a session and load any previously persisted history.

        Args:
            provider: Model backend used to generate assistant turns.
            workspace: Root directory; resolved to an absolute path.
            session_id: Raw identifier, sanitised before use as a directory name.
            skill_dirs: Skill directories; defaults to ``<workspace>/skills``.
            tool_registry: Tool registry; defaults to the registry built for
                ``<workspace>/data``.
            system_prompt: Base prompt; defaults to ``DEFAULT_SYSTEM_PROMPT``.
            max_iterations: Maximum provider rounds per ``stream_ask`` call.
            max_history_turns: Sizes the history window fed to the compressor
                and the tool trace store.  Zero or less means no history.
            max_persisted_turns: Number of user/assistant pairs kept in memory
                and on disk.  Zero or less means nothing is kept.
            active_skills: Skill names initially marked active.
        """
        self.provider = provider
        self.workspace = Path(workspace).resolve()
        self.session_id = _sanitize_session_id(session_id)
        self.skill_store = SkillStore(skill_dirs or [self.workspace / "skills"])
        self.tool_registry = tool_registry or build_default_registry(self.workspace / "data")
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self.max_iterations = max_iterations
        self.max_history_turns = max_history_turns
        self.max_persisted_turns = max_persisted_turns
        self.session_dir = self.workspace / "agent_memory" / "sessions" / self.session_id
        self.history_path = self.session_dir / "recent_history.json"
        self.tool_trace_store = ToolTraceStore(self.session_dir / "tool_trace.json", max_turns=max_history_turns)
        self.context_compressor = RollingContextCompressor()
        self.session_state: Dict[str, Any] = {
            "workspace": str(self.workspace),
            "active_skills": list(active_skills or []),
            "session_id": self.session_id,
            "_cancel_requested": False,
        }
        self.history: List[Dict[str, Any]] = self._load_recent_history()
        self.last_messages: List[Dict[str, Any]] = []

    def ask(self, user_message: str) -> AgentRunResult:
        """Run ``stream_ask`` to completion and return its final answer.

        Raises:
            RuntimeError: If the stream ends without a ``final`` event.
        """
        final: ChatEvent | None = None
        for event in self.stream_ask(user_message):
            final = event
        if final is None or final.type != "final":
            raise RuntimeError("Conversation ended without a final response")
        return AgentRunResult(
            final_response=final.final_response,
            messages=list(self.last_messages),
            session=dict(self.session_state),
        )

    def request_cancel(self) -> None:
        """Set the cancel flag that the running loop polls."""
        self.session_state["_cancel_requested"] = True

    def clear_cancel(self) -> None:
        """Reset the cancel flag."""
        self.session_state["_cancel_requested"] = False

    def cancellation_requested(self) -> bool:
        """Return True while the cancel flag is set."""
        return bool(self.session_state.get("_cancel_requested"))

    def _raise_if_cancelled(self) -> None:
        """Raise ``AgentCancelled`` if the cancel flag is set."""
        if self.cancellation_requested():
            raise AgentCancelled("Agent run cancelled by user.")

    def _pending_cancel_message(self) -> Optional[str]:
        """Return the cancellation text if a cancel is pending, else None."""
        try:
            self._raise_if_cancelled()
        except AgentCancelled as exc:
            return str(exc)
        return None

    def stream_ask(self, user_message: str) -> Iterator[ChatEvent]:
        """Run the provider/tool loop for one user message, yielding events.

        Each round streams one assistant turn from the provider and, if that
        turn requests tools, executes them and feeds the results back.  The
        generator always ends with exactly one ``final`` event whose
        ``final_reason`` is ``completed``, ``cancelled``, ``provider_error``,
        ``repeat_loop_detected`` or ``max_iterations``.

        Cancellation observed at any checkpoint is reported as a ``final``
        event with reason ``cancelled``; it never escapes as an exception.
        Only ``completed``, ``repeat_loop_detected`` and ``max_iterations``
        outcomes are appended to the persisted history.
        """
        self.clear_cancel()
        messages = self.build_messages(user_message)
        self.last_messages = list(messages)
        trace_turn = self.tool_trace_store.begin_turn(user_message)
        # Both counters accumulate over the whole run; they are not reset
        # between rounds.
        repeat_counter: dict[str, int] = {}
        tool_name_counter: dict[str, int] = {}

        for iteration in range(1, self.max_iterations + 1):
            cancel_text = self._pending_cancel_message()
            if cancel_text is not None:
                yield self._abort_turn(messages, trace_turn, cancel_text, iteration, "cancelled")
                return

            yield ChatEvent(type="round_start", iteration=iteration, max_iterations=self.max_iterations)

            try:
                turn: AssistantTurn | None = None
                stream = self.provider.stream_generate(
                    messages,
                    self.tool_registry.definitions(),
                    should_cancel=self.cancellation_requested,
                )
                for stream_event in stream:
                    self._raise_if_cancelled()
                    if stream_event.type == "reasoning" and stream_event.delta:
                        yield ChatEvent(
                            type="reasoning_delta",
                            delta=stream_event.delta,
                            iteration=iteration,
                            max_iterations=self.max_iterations,
                        )
                    elif stream_event.type == "content" and stream_event.delta:
                        yield ChatEvent(
                            type="content_delta",
                            delta=stream_event.delta,
                            iteration=iteration,
                            max_iterations=self.max_iterations,
                        )
                    elif stream_event.type == "turn":
                        turn = stream_event.turn
                if turn is None:
                    raise RuntimeError("Provider stream ended without a final turn")
                self._raise_if_cancelled()
            except AgentCancelled as exc:
                yield self._abort_turn(messages, trace_turn, str(exc), iteration, "cancelled")
                return
            except Exception as exc:
                # Any other failure while streaming is surfaced to the caller
                # as the final response text rather than raised.
                yield self._abort_turn(messages, trace_turn, str(exc), iteration, "provider_error")
                return

            yield ChatEvent(
                type="assistant_turn",
                turn_content=turn.content or "",
                turn_reasoning=turn.reasoning or "",
                iteration=iteration,
                max_iterations=self.max_iterations,
            )
            messages.append(_assistant_message(turn))

            if not turn.tool_calls:
                # A turn without tool requests is the answer.
                yield self._complete_turn(
                    user_message, messages, trace_turn, turn.content or "", iteration, "completed"
                )
                return

            ctx = ToolContext(
                workspace=self.workspace,
                session=self.session_state,
                tool_trace_store=self.tool_trace_store,
            )
            loop_message: Optional[str] = None
            for call in turn.tool_calls:
                cancel_text = self._pending_cancel_message()
                if cancel_text is not None:
                    yield self._abort_turn(messages, trace_turn, cancel_text, iteration, "cancelled")
                    return

                # Canonical JSON so argument order does not hide a repeat.
                call_key = json.dumps(
                    {"name": call.name, "args": call.arguments}, sort_keys=True, ensure_ascii=False
                )
                repeat_counter[call_key] = repeat_counter.get(call_key, 0) + 1
                if repeat_counter[call_key] >= self._REPEAT_LIMIT:
                    loop_message = "检测到重复工具调用(相同参数连续调用 {0} 次),已自动终止循环。请检查工具参数或换用更大参数的模型。".format(self._REPEAT_LIMIT)
                    break
                tool_name_counter[call.name] = tool_name_counter.get(call.name, 0) + 1
                if tool_name_counter[call.name] >= self._TOOL_NAME_LIMIT:
                    loop_message = "检测到同一工具「{0}」被连续调用 {1} 次(参数略有不同但未产生有效进展),已自动终止循环。建议换用更大参数的模型。".format(call.name, self._TOOL_NAME_LIMIT)
                    break

                yield ChatEvent(
                    type="tool_call",
                    tool_name=call.name,
                    tool_args=call.arguments,
                    iteration=iteration,
                    max_iterations=self.max_iterations,
                )
                try:
                    result = self.tool_registry.execute(call.name, call.arguments, ctx)
                except AgentCancelled as exc:
                    yield self._abort_turn(messages, trace_turn, str(exc), iteration, "cancelled")
                    return
                except Exception as exc:
                    # Tool failures are fed back to the model as a JSON error
                    # report instead of ending the run.
                    result = _tool_error_result(call.name, call.arguments, exc)

                self.tool_trace_store.record_tool_call(
                    trace_turn,
                    tool_name=call.name,
                    arguments=call.arguments,
                    thought=turn.reasoning or "",
                    result=result,
                    tool_call_id=call.id,
                )
                messages.append({"role": "tool", "tool_call_id": call.id, "name": call.name, "content": result})
                yield ChatEvent(
                    type="tool_result",
                    tool_name=call.name,
                    tool_args=call.arguments,
                    tool_result=result,
                    iteration=iteration,
                    max_iterations=self.max_iterations,
                )

            if loop_message is not None:
                yield self._complete_turn(
                    user_message, messages, trace_turn, loop_message, iteration, "repeat_loop_detected"
                )
                return

            # Rebuild the system prompt before the next round.
            # NOTE(review): presumably because tools can change session state
            # such as active_skills -- confirm against the tool implementations.
            messages[0]["content"] = self._system_prompt()

        yield self._complete_turn(
            user_message,
            messages,
            trace_turn,
            "Agent reached max_iterations before producing a final response.",
            self.max_iterations,
            "max_iterations",
        )

    def _complete_turn(
        self,
        user_message: str,
        messages: List[Dict[str, Any]],
        trace_turn: ToolTraceTurn,
        final: str,
        iteration: int,
        reason: str,
    ) -> ChatEvent:
        """Persist the exchange, close the trace turn and build the final event."""
        self._finalize_turn(user_message, final, trace_turn)
        self.last_messages = list(messages)
        return ChatEvent(
            type="final",
            final_response=final,
            iteration=iteration,
            max_iterations=self.max_iterations,
            final_reason=reason,
        )

    def _abort_turn(
        self,
        messages: List[Dict[str, Any]],
        trace_turn: ToolTraceTurn,
        final: str,
        iteration: int,
        reason: str,
    ) -> ChatEvent:
        """Close the trace turn without touching history and build the final event."""
        self.last_messages = list(messages)
        self.tool_trace_store.finish_turn(trace_turn, final)
        return ChatEvent(
            type="final",
            final_response=final,
            iteration=iteration,
            max_iterations=self.max_iterations,
            final_reason=reason,
        )

    @staticmethod
    def _tail(items: List[Dict[str, Any]], count: int) -> List[Dict[str, Any]]:
        """Return the last ``count`` items, or an empty list when ``count <= 0``.

        A plain ``items[-count:]`` would return the whole list for ``count == 0``
        because ``-0 == 0``.
        """
        return items[-count:] if count > 0 else []

    def build_messages(self, user_message: str) -> List[Dict[str, Any]]:
        """Assemble the provider input for a new user message.

        The result is the system prompt, an optional summary message from the
        compressor, the compressor's tail messages, then the new user message.
        Compression statistics are stored in ``session_state["compression"]``.
        """
        recent_history = self._tail(self.history, self.max_history_turns * 6)
        compression = self.context_compressor.compact(recent_history)
        messages: List[Dict[str, Any]] = [{"role": "system", "content": self._system_prompt()}]
        if compression.summary_message:
            messages.append(compression.summary_message)
        messages.extend(compression.tail_messages)
        messages.append({"role": "user", "content": user_message})
        self.session_state["compression"] = {
            "did_compact": compression.did_compact,
            "estimated_tokens": compression.estimated_tokens,
            "has_summary": bool(compression.summary_message),
        }
        return messages

    def persisted_dialog_messages(self) -> List[Dict[str, str]]:
        """Return a shallow copy of the in-memory history list."""
        return list(self.history)

    def _append_history(self, user_message: str, final_content: str) -> None:
        """Append one user/assistant exchange, then trim and save the history."""
        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": final_content})
        self._trim_recent_history()
        self._save_recent_history()

    def _finalize_turn(self, user_message: str, final_content: str, trace_turn: ToolTraceTurn) -> None:
        """Record the exchange in history and close the tool trace turn."""
        self._append_history(user_message, final_content)
        self.tool_trace_store.finish_turn(trace_turn, final_content)

    def _load_recent_history(self) -> List[Dict[str, str]]:
        """Read persisted history, tolerating a missing or malformed file.

        Returns an empty list when the file is absent, unreadable, not valid
        UTF-8 JSON, or does not hold a ``messages`` list.  Entries that are not
        dicts, have a role other than ``user``/``assistant``, or have missing
        or blank content are skipped.  Content is stored stripped.
        """
        if not self.history_path.is_file():
            return []
        try:
            raw = json.loads(self.history_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            return []
        items = raw.get("messages") if isinstance(raw, dict) else None
        if not isinstance(items, list):
            return []
        history: List[Dict[str, str]] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            role = str(item.get("role", "")).strip()
            raw_content = item.get("content")
            # A JSON null must not be kept as the literal text "None".
            content = "" if raw_content is None else str(raw_content).strip()
            if role not in {"user", "assistant"} or not content:
                continue
            history.append({"role": role, "content": content})
        return self._tail(history, self.max_persisted_turns * 2)

    def _save_recent_history(self) -> None:
        """Write the bounded history to ``history_path`` as indented UTF-8 JSON."""
        self.history_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "session_id": self.session_id,
            "messages": self._tail(self.history, self.max_persisted_turns * 2),
            "max_persisted_turns": self.max_persisted_turns,
        }
        self.history_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")

    def _trim_recent_history(self) -> None:
        """Keep only the newest ``max_persisted_turns`` exchanges in memory."""
        self.history = self._tail(self.history, self.max_persisted_turns * 2)

    def _system_prompt(self) -> str:
        """Build the system prompt from the current session configuration."""
        return build_system_prompt(
            skill_store=self.skill_store,
            active_skills=self.session_state.get("active_skills", []),
            tool_registry=self.tool_registry,
            workspace=str(self.workspace),
            base_prompt=self.system_prompt,
        )
def _assistant_message(turn: AssistantTurn) -> Dict[str, Any]:
message: Dict[str, Any] = {"role": "assistant", "content": turn.content or ""}
if turn.tool_calls:
message["tool_calls"] = [
{
"id": call.id,
"type": "function",
"function": {"name": call.name, "arguments": json.dumps(call.arguments, ensure_ascii=False)},
}
for call in turn.tool_calls
]
return message
def _tool_error_result(tool_name: str, tool_args: Dict[str, Any], exc: Exception) -> str:
payload = {
"success": False,
"error": f"{type(exc).__name__}: {exc}",
"tool_name": tool_name,
"tool_args": tool_args,
"retryable": False,
"traceback": traceback.format_exc(limit=8),
}
return json.dumps(payload, ensure_ascii=False, indent=2)
def _sanitize_session_id(value: str) -> str:
text = re.sub(r"[^A-Za-z0-9._-]+", "_", str(value or "default")).strip("._")
return text or "default"