"""Conversation session: drives the provider/tool loop and persists recent history."""

from __future__ import annotations

from dataclasses import dataclass, field
import json
from pathlib import Path
import re
import traceback
from typing import Any, Dict, Iterator, List, Optional

from core_agent.compression import RollingContextCompressor
from core_agent.exceptions import AgentCancelled
from prompts import DEFAULT_SYSTEM_PROMPT, build_system_prompt
from providers.base import AgentProvider, AssistantTurn
from skills import SkillStore
from tools.default_tools import build_default_registry
from tools.registry import ToolContext, ToolRegistry
from tools.tool_trace import ToolTraceStore, ToolTraceTurn

# A run is stopped once the identical call (same name and arguments) has been
# seen this many times within one ``stream_ask`` run.
_REPEAT_LIMIT = 3
# A run is stopped once the same tool name has been called this many times
# within one ``stream_ask`` run, regardless of arguments.
_TOOL_NAME_LIMIT = 6
# Final response text used when a run is cancelled.
_CANCELLED_MESSAGE = "Agent run cancelled by user."
_NOTE_TRUNCATED = (
    "Full tool result was stored in tool_trace.json. "
    "Use tool_trace_query if more detail is needed."
)


@dataclass(slots=True)
class ChatEvent:
    """One event yielded by ``ConversationSession.stream_ask``.

    ``type`` is one of ``round_start``, ``reasoning_delta``, ``content_delta``,
    ``assistant_turn``, ``tool_call``, ``tool_result`` or ``final``. Only the
    fields relevant to a given type are populated; the rest keep their
    defaults. For ``final`` events ``final_reason`` is one of ``completed``,
    ``cancelled``, ``provider_error``, ``repeat_loop_detected`` or
    ``max_iterations``.
    """

    type: str
    delta: str = ""
    tool_name: str = ""
    tool_args: Dict[str, Any] = field(default_factory=dict)
    tool_result: str = ""
    final_response: str = ""
    turn_content: str = ""
    turn_reasoning: str = ""
    iteration: int = 0
    max_iterations: int = 0
    final_reason: str = ""


@dataclass(slots=True)
class AgentRunResult:
    """Result of a blocking ``ConversationSession.ask`` call."""

    final_response: str
    messages: List[Dict[str, Any]]
    session: Dict[str, Any] = field(default_factory=dict)


class ConversationSession:
    """Stateful chat session that loops between a provider and a tool registry.

    The session keeps a rolling user/assistant history that is persisted to
    ``<workspace>/agent_memory/sessions/<session_id>/recent_history.json`` and
    a tool trace stored next to it in ``tool_trace.json``.
    """

    def __init__(
        self,
        *,
        provider: AgentProvider,
        workspace: str | Path,
        session_id: str = "default",
        skill_dirs: Optional[List[str | Path]] = None,
        tool_registry: Optional[ToolRegistry] = None,
        system_prompt: Optional[str] = None,
        max_iterations: int = 12,
        max_history_turns: int = 20,
        max_persisted_turns: int = 20,
        active_skills: Optional[List[str]] = None,
    ) -> None:
        """Create a session and load any previously persisted history.

        Args:
            provider: Model provider used to generate assistant turns.
            workspace: Workspace root; resolved to an absolute path.
            session_id: Session name; sanitized before it is used as a
                directory name.
            skill_dirs: Skill directories; defaults to ``<workspace>/skills``.
            tool_registry: Tool registry; defaults to the registry built by
                ``build_default_registry(<workspace>/data)``.
            system_prompt: Base system prompt; defaults to
                ``DEFAULT_SYSTEM_PROMPT``.
            max_iterations: Maximum provider rounds per ``stream_ask`` run.
            max_history_turns: Bounds the history fed to the model (the last
                ``max_history_turns * 6`` messages) and the tool trace store.
            max_persisted_turns: Bounds the stored history (the last
                ``max_persisted_turns * 2`` messages).
            active_skills: Names of initially active skills.
        """
        self.provider = provider
        self.workspace = Path(workspace).resolve()
        self.session_id = _sanitize_session_id(session_id)
        self.skill_store = SkillStore(skill_dirs or [self.workspace / "skills"])
        self.tool_registry = tool_registry or build_default_registry(self.workspace / "data")
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self.max_iterations = max_iterations
        self.max_history_turns = max_history_turns
        self.max_persisted_turns = max_persisted_turns
        # Limits applied when an oversized tool result is summarized for the model.
        self.max_tool_result_chars_for_model = 12000
        self.max_tool_string_field_chars = 1500
        self.max_tool_list_items = 8
        self.max_tool_dict_items = 20
        self.session_dir = self.workspace / "agent_memory" / "sessions" / self.session_id
        self.history_path = self.session_dir / "recent_history.json"
        self.tool_trace_store = ToolTraceStore(
            self.session_dir / "tool_trace.json", max_turns=max_history_turns
        )
        self.context_compressor = RollingContextCompressor()
        self.session_state: Dict[str, Any] = {
            "workspace": str(self.workspace),
            "active_skills": list(active_skills or []),
            "session_id": self.session_id,
            "_cancel_requested": False,
        }
        self.history: List[Dict[str, Any]] = self._load_recent_history()
        self.last_messages: List[Dict[str, Any]] = []

    def ask(self, user_message: str) -> AgentRunResult:
        """Run ``stream_ask`` to completion and return the final result.

        Raises:
            RuntimeError: If the stream ends without a ``final`` event.
        """
        final: ChatEvent | None = None
        for event in self.stream_ask(user_message):
            final = event
        if final is None or final.type != "final":
            raise RuntimeError("Conversation ended without a final response")
        return AgentRunResult(
            final_response=final.final_response,
            messages=list(self.last_messages),
            session=dict(self.session_state),
        )

    def request_cancel(self) -> None:
        """Set the cancel flag that the running loop polls."""
        self.session_state["_cancel_requested"] = True

    def clear_cancel(self) -> None:
        """Reset the cancel flag."""
        self.session_state["_cancel_requested"] = False

    def cancellation_requested(self) -> bool:
        """Return whether cancellation has been requested."""
        return bool(self.session_state.get("_cancel_requested"))

    def _raise_if_cancelled(self) -> None:
        """Raise ``AgentCancelled`` if cancellation has been requested."""
        if self.cancellation_requested():
            raise AgentCancelled(_CANCELLED_MESSAGE)

    def stream_ask(self, user_message: str) -> Iterator[ChatEvent]:
        """Run one user turn, yielding progress events and exactly one ``final``.

        Each round streams an assistant turn from the provider. A turn without
        tool calls completes the run; otherwise every tool call is executed
        and its result appended to the conversation before the next round.
        The run also ends on cancellation, on a provider error, when the
        repeat-call limits are hit, or after ``max_iterations`` rounds.

        Cancellation and provider errors finish the tool trace but do not
        append the exchange to the persisted history.
        """
        self.clear_cancel()
        messages = self.build_messages(user_message)
        self.last_messages = list(messages)
        current_trace_turn = self.tool_trace_store.begin_turn(user_message)
        # Both counters accumulate over the whole run; they are not reset
        # between rounds.
        repeat_counter: dict[str, int] = {}
        tool_name_counter: dict[str, int] = {}

        for iteration in range(1, self.max_iterations + 1):
            # Report cancellation as a final event, like every other cancel
            # path, instead of letting AgentCancelled escape the generator.
            if self.cancellation_requested():
                yield self._abort_turn(
                    messages, current_trace_turn, _CANCELLED_MESSAGE,
                    iteration=iteration, reason="cancelled",
                )
                return
            yield ChatEvent(
                type="round_start", iteration=iteration, max_iterations=self.max_iterations
            )

            try:
                turn: AssistantTurn | None = None
                for stream_event in self.provider.stream_generate(
                    messages,
                    self.tool_registry.definitions(),
                    should_cancel=self.cancellation_requested,
                ):
                    self._raise_if_cancelled()
                    if stream_event.type == "reasoning" and stream_event.delta:
                        yield ChatEvent(
                            type="reasoning_delta",
                            delta=stream_event.delta,
                            iteration=iteration,
                            max_iterations=self.max_iterations,
                        )
                    elif stream_event.type == "content" and stream_event.delta:
                        yield ChatEvent(
                            type="content_delta",
                            delta=stream_event.delta,
                            iteration=iteration,
                            max_iterations=self.max_iterations,
                        )
                    elif stream_event.type == "turn":
                        turn = stream_event.turn
                if turn is None:
                    raise RuntimeError("Provider stream ended without a final turn")
                self._raise_if_cancelled()
            except AgentCancelled as exc:
                yield self._abort_turn(
                    messages, current_trace_turn, str(exc),
                    iteration=iteration, reason="cancelled",
                )
                return
            except Exception as exc:
                # Any other failure while streaming is reported to the caller
                # as the final response rather than raised.
                yield self._abort_turn(
                    messages, current_trace_turn, str(exc),
                    iteration=iteration, reason="provider_error",
                )
                return

            yield ChatEvent(
                type="assistant_turn",
                turn_content=turn.content or "",
                turn_reasoning=turn.reasoning or "",
                iteration=iteration,
                max_iterations=self.max_iterations,
            )
            messages.append(_assistant_message(turn))

            if not turn.tool_calls:
                yield self._complete_turn(
                    user_message, messages, current_trace_turn, turn.content or "",
                    iteration=iteration, reason="completed",
                )
                return

            ctx = ToolContext(
                workspace=self.workspace,
                session=self.session_state,
                tool_trace_store=self.tool_trace_store,
            )
            for call in turn.tool_calls:
                if self.cancellation_requested():
                    yield self._abort_turn(
                        messages, current_trace_turn, _CANCELLED_MESSAGE,
                        iteration=iteration, reason="cancelled",
                    )
                    return

                loop_message = _detect_tool_loop(call, repeat_counter, tool_name_counter)
                if loop_message:
                    yield self._complete_turn(
                        user_message, messages, current_trace_turn, loop_message,
                        iteration=iteration, reason="repeat_loop_detected",
                    )
                    return

                yield ChatEvent(
                    type="tool_call",
                    tool_name=call.name,
                    tool_args=call.arguments,
                    iteration=iteration,
                    max_iterations=self.max_iterations,
                )
                try:
                    result = self.tool_registry.execute(call.name, call.arguments, ctx)
                except AgentCancelled as exc:
                    yield self._abort_turn(
                        messages, current_trace_turn, str(exc),
                        iteration=iteration, reason="cancelled",
                    )
                    return
                except Exception as exc:
                    # A failing tool does not end the run; the error payload is
                    # passed back to the model as the tool result.
                    result = _tool_error_result(call.name, call.arguments, exc)

                self.tool_trace_store.record_tool_call(
                    current_trace_turn,
                    tool_name=call.name,
                    arguments=call.arguments,
                    thought=turn.reasoning or "",
                    result=result,
                    tool_call_id=call.id,
                )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": call.name,
                        "content": self._tool_result_for_model(call.name, result),
                    }
                )
                yield ChatEvent(
                    type="tool_result",
                    tool_name=call.name,
                    tool_args=call.arguments,
                    tool_result=result,
                    iteration=iteration,
                    max_iterations=self.max_iterations,
                )

            # Rebuild the system prompt before the next round; tools receive
            # the session state and may have changed it.
            messages[0]["content"] = self._system_prompt()

        yield self._complete_turn(
            user_message, messages, current_trace_turn,
            "Agent reached max_iterations before producing a final response.",
            iteration=self.max_iterations, reason="max_iterations",
        )

    def _abort_turn(
        self,
        messages: List[Dict[str, Any]],
        trace_turn: ToolTraceTurn,
        message: str,
        *,
        iteration: int,
        reason: str,
    ) -> ChatEvent:
        """Finish the trace turn without touching history; build the final event."""
        self.last_messages = list(messages)
        self.tool_trace_store.finish_turn(trace_turn, message)
        return ChatEvent(
            type="final",
            final_response=message,
            iteration=iteration,
            max_iterations=self.max_iterations,
            final_reason=reason,
        )

    def _complete_turn(
        self,
        user_message: str,
        messages: List[Dict[str, Any]],
        trace_turn: ToolTraceTurn,
        message: str,
        *,
        iteration: int,
        reason: str,
    ) -> ChatEvent:
        """Persist the exchange, finish the trace turn and build the final event."""
        self._finalize_turn(user_message, message, trace_turn)
        self.last_messages = list(messages)
        return ChatEvent(
            type="final",
            final_response=message,
            iteration=iteration,
            max_iterations=self.max_iterations,
            final_reason=reason,
        )

    def build_messages(self, user_message: str) -> List[Dict[str, Any]]:
        """Build the message list for a new turn.

        The list is: system prompt, optional compression summary, the
        compressor's tail messages, then the new user message. Compression
        statistics are stored in ``session_state["compression"]``.
        """
        recent_history = _tail(self.history, self.max_history_turns * 6)
        compression = self.context_compressor.compact(recent_history)
        messages: List[Dict[str, Any]] = [{"role": "system", "content": self._system_prompt()}]
        if compression.summary_message:
            messages.append(compression.summary_message)
        messages.extend(compression.tail_messages)
        messages.append({"role": "user", "content": user_message})
        self.session_state["compression"] = {
            "did_compact": compression.did_compact,
            "estimated_tokens": compression.estimated_tokens,
            "has_summary": bool(compression.summary_message),
        }
        return messages

    def persisted_dialog_messages(self) -> List[Dict[str, str]]:
        """Return a shallow copy of the in-memory history."""
        return list(self.history)

    def _append_history(self, user_message: str, final_content: str) -> None:
        """Append one user/assistant exchange, then trim and save the history."""
        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": final_content})
        self._trim_recent_history()
        self._save_recent_history()

    def _finalize_turn(
        self, user_message: str, final_content: str, trace_turn: ToolTraceTurn
    ) -> None:
        """Record the exchange in history and finish the tool trace turn."""
        self._append_history(user_message, final_content)
        self.tool_trace_store.finish_turn(trace_turn, final_content)

    def _load_recent_history(self) -> List[Dict[str, str]]:
        """Load persisted history; return ``[]`` if it is missing or unreadable.

        Only entries whose role is ``user`` or ``assistant`` and whose content
        is non-empty are kept.
        """
        if not self.history_path.is_file():
            return []
        try:
            raw = json.loads(self.history_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError, so
            # a corrupt file cannot prevent the session from being created.
            return []
        items = raw.get("messages", []) if isinstance(raw, dict) else []
        history: List[Dict[str, str]] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            role = str(item.get("role", "")).strip()
            content = str(item.get("content", "")).strip()
            if role not in {"user", "assistant"} or not content:
                continue
            history.append({"role": role, "content": content})
        return _tail(history, self.max_persisted_turns * 2)

    def _save_recent_history(self) -> None:
        """Write the bounded history to ``history_path`` as JSON."""
        self.history_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "session_id": self.session_id,
            "messages": _tail(self.history, self.max_persisted_turns * 2),
            "max_persisted_turns": self.max_persisted_turns,
        }
        self.history_path.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8"
        )

    def _trim_recent_history(self) -> None:
        """Keep only the last ``max_persisted_turns * 2`` history messages."""
        self.history = _tail(self.history, self.max_persisted_turns * 2)

    def _system_prompt(self) -> str:
        """Build the system prompt from the current session state."""
        return build_system_prompt(
            skill_store=self.skill_store,
            active_skills=self.session_state.get("active_skills", []),
            tool_registry=self.tool_registry,
            workspace=str(self.workspace),
            base_prompt=self.system_prompt,
        )

    def _tool_result_for_model(self, tool_name: str, result: str) -> str:
        """Return the tool result to send to the model.

        Results up to ``max_tool_result_chars_for_model`` characters are
        returned unchanged. Longer results are replaced by a JSON summary
        holding a compacted copy of the parsed payload, or of the raw text
        when the result is not parseable JSON.
        """
        text = str(result or "")
        if len(text) <= self.max_tool_result_chars_for_model:
            return text
        try:
            payload: Any = json.loads(text)
        except (ValueError, RecursionError):
            # Not JSON, or nested too deeply to parse: summarize as plain text.
            payload = text
        summary = {
            "_tool_name": tool_name,
            "_truncated_for_model": True,
            "_original_chars": len(text),
            "_note": _NOTE_TRUNCATED,
            "result_summary": self._compact_tool_value(payload, depth=0),
        }
        return json.dumps(summary, ensure_ascii=False, indent=2)

    def _compact_tool_value(self, value: Any, *, depth: int) -> Any:
        """Recursively shrink a JSON-like value.

        Dicts keep their first ``max_tool_dict_items`` entries, lists their
        first ``max_tool_list_items`` items, and containers at depth 4 or
        deeper are replaced by a placeholder string.
        """
        if depth >= 4:
            if isinstance(value, (dict, list)):
                return f"[{type(value).__name__} truncated]"
            return self._compact_scalar(value)
        if isinstance(value, dict):
            items = list(value.items())
            compact: dict[str, Any] = {}
            for index, (key, item) in enumerate(items):
                if index >= self.max_tool_dict_items:
                    compact["_remaining_fields"] = len(items) - self.max_tool_dict_items
                    break
                compact[str(key)] = self._compact_tool_value(item, depth=depth + 1)
            return compact
        if isinstance(value, list):
            compact_list = [
                self._compact_tool_value(item, depth=depth + 1)
                for item in value[: self.max_tool_list_items]
            ]
            if len(value) > self.max_tool_list_items:
                compact_list.append(f"... {len(value) - self.max_tool_list_items} more items")
            return compact_list
        return self._compact_scalar(value)

    def _compact_scalar(self, value: Any) -> Any:
        """Collapse whitespace in strings and cap their length; pass others through."""
        if isinstance(value, str):
            compact = " ".join(value.split())
            if len(compact) > self.max_tool_string_field_chars:
                return compact[: self.max_tool_string_field_chars] + "...[truncated]"
            return compact
        return value


def _tail(items: List[Any], count: int) -> List[Any]:
    """Return a new list holding at most the last ``count`` items.

    A plain ``items[-count:]`` returns the whole list when ``count`` is 0
    (``-0 == 0``); a non-positive ``count`` yields an empty list here.
    """
    if count <= 0:
        return []
    return items[-count:]


def _detect_tool_loop(
    call: Any, repeat_counter: Dict[str, int], tool_name_counter: Dict[str, int]
) -> str:
    """Update the call counters and return a stop message once a limit is hit.

    Returns an empty string while both counters stay below their limits. The
    per-name counter is only incremented when the identical-call limit has not
    been reached.
    """
    call_key = json.dumps(
        {"name": call.name, "args": call.arguments}, sort_keys=True, ensure_ascii=False
    )
    repeat_counter[call_key] = repeat_counter.get(call_key, 0) + 1
    if repeat_counter[call_key] >= _REPEAT_LIMIT:
        return "检测到重复工具调用(相同参数连续调用 {0} 次),已自动终止循环。请检查工具参数或换用更大参数的模型。".format(_REPEAT_LIMIT)
    tool_name_counter[call.name] = tool_name_counter.get(call.name, 0) + 1
    if tool_name_counter[call.name] >= _TOOL_NAME_LIMIT:
        return "检测到同一工具「{0}」被连续调用 {1} 次(参数略有不同但未产生有效进展),已自动终止循环。建议换用更大参数的模型。".format(call.name, _TOOL_NAME_LIMIT)
    return ""


def _assistant_message(turn: AssistantTurn) -> Dict[str, Any]:
    """Convert an assistant turn into a chat message dict.

    Tool calls, when present, are emitted as ``function`` entries whose
    arguments are JSON-encoded strings.
    """
    message: Dict[str, Any] = {"role": "assistant", "content": turn.content or ""}
    if turn.tool_calls:
        message["tool_calls"] = [
            {
                "id": call.id,
                "type": "function",
                "function": {
                    "name": call.name,
                    "arguments": json.dumps(call.arguments, ensure_ascii=False),
                },
            }
            for call in turn.tool_calls
        ]
    return message


def _tool_error_result(tool_name: str, tool_args: Dict[str, Any], exc: Exception) -> str:
    """Serialize a tool failure as a JSON error payload.

    Intended to be called from inside the ``except`` block handling ``exc``,
    because the traceback is taken from the exception currently being handled.
    """
    payload = {
        "success": False,
        "error": f"{type(exc).__name__}: {exc}",
        "tool_name": tool_name,
        "tool_args": tool_args,
        "retryable": False,
        "traceback": traceback.format_exc(limit=8),
    }
    return json.dumps(payload, ensure_ascii=False, indent=2)


def _sanitize_session_id(value: str) -> str:
    """Reduce a session id to characters that are safe in a directory name.

    Runs of characters outside ``[A-Za-z0-9._-]`` become ``_``, leading and
    trailing ``.``/``_`` are stripped, and an empty result becomes
    ``"default"``.
    """
    text = re.sub(r"[^A-Za-z0-9._-]+", "_", str(value or "default")).strip("._")
    return text or "default"