meeting_memory/tools/tool_trace.py

226 lines
8.2 KiB
Python
Raw Normal View History

2026-06-24 07:05:19 +00:00
from __future__ import annotations
import json
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
@dataclass(slots=True)
class ToolTraceEntry:
tool_name: str
arguments: Dict[str, Any]
thought: str = ""
result: str = ""
tool_call_id: str = ""
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
return {
"tool_name": self.tool_name,
"arguments": deepcopy(self.arguments),
"thought": self.thought,
"result": self.result,
"tool_call_id": self.tool_call_id,
"created_at": self.created_at,
}
@dataclass(slots=True)
class ToolTraceTurn:
turn_id: int
started_at: str
user_message: str
class ToolTraceStore:
    """JSON-file-backed log of the tool calls made during recent turns.

    The file holds one JSON object with the keys ``max_turns``,
    ``next_turn_id``, ``active_turn`` (a dict or ``null``) and ``turns`` (a
    list of finished turns, oldest first). Every mutating method reloads the
    file, applies its change and writes the whole object back.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        max_turns: int = 20,
        max_result_chars: int = 24000,
        max_thought_chars: int = 12000,
    ) -> None:
        """Open (or create) the trace file at ``path``.

        Args:
            path: Location of the JSON trace file.
            max_turns: Number of finished turns to retain; clamped to >= 1.
            max_result_chars: Character limit for stored results and final
                answers; clamped to >= 1000.
            max_thought_chars: Character limit for stored thoughts; clamped
                to >= 1000.
        """
        self.path = Path(path)
        self.max_turns = max(1, int(max_turns))
        self.max_result_chars = max(1000, int(max_result_chars))
        self.max_thought_chars = max(1000, int(max_thought_chars))
        self._next_turn_id = 1
        self._bootstrap()

    def begin_turn(self, user_message: str) -> ToolTraceTurn:
        """Start a new turn and persist it as the file's active turn.

        Any active turn already present in the file is overwritten.

        Returns:
            The handle to pass to ``record_tool_call`` and ``finish_turn``.
        """
        turn = ToolTraceTurn(
            turn_id=self._next_turn_id,
            started_at=datetime.now().isoformat(),
            user_message=user_message,
        )
        payload = self._load_payload()
        payload["active_turn"] = {
            "turn_id": turn.turn_id,
            "started_at": turn.started_at,
            "completed_at": "",
            "user_message": turn.user_message,
            "assistant_final": "",
            "tool_calls": [],
        }
        payload["next_turn_id"] = turn.turn_id + 1
        self._next_turn_id = turn.turn_id + 1
        self._save_payload(payload)
        return turn

    def record_tool_call(
        self,
        turn: ToolTraceTurn,
        *,
        tool_name: str,
        arguments: Dict[str, Any],
        thought: str = "",
        result: str = "",
        tool_call_id: str = "",
    ) -> None:
        """Append one tool call to the active turn.

        Does nothing when the file's active turn is missing or belongs to a
        different ``turn_id``. ``thought`` and ``result`` are truncated to the
        configured character limits before being stored.
        """
        payload = self._load_payload()
        active_turn = self._matching_active_turn(payload, turn)
        if active_turn is None:
            return
        entry = ToolTraceEntry(
            tool_name=tool_name,
            arguments=deepcopy(arguments),
            thought=self._trim_text(thought, self.max_thought_chars),
            result=self._trim_text(result, self.max_result_chars),
            tool_call_id=tool_call_id,
        )
        tool_calls = active_turn.get("tool_calls")
        if not isinstance(tool_calls, list):
            # A missing or malformed slot must not swallow the entry silently.
            tool_calls = []
            active_turn["tool_calls"] = tool_calls
        tool_calls.append(entry.to_dict())
        payload["active_turn"] = active_turn
        self._save_payload(payload)

    def finish_turn(self, turn: ToolTraceTurn, assistant_final: str) -> None:
        """Close the active turn and move it into the finished-turn history.

        Does nothing when the file's active turn is missing or belongs to a
        different ``turn_id``. Only the newest ``max_turns`` finished turns
        are kept.
        """
        payload = self._load_payload()
        active_turn = self._matching_active_turn(payload, turn)
        if active_turn is None:
            return
        active_turn["assistant_final"] = self._trim_text(assistant_final, self.max_result_chars)
        active_turn["completed_at"] = datetime.now().isoformat()
        turns = payload.get("turns", [])
        if not isinstance(turns, list):
            turns = []
        turns.append(active_turn)
        payload["turns"] = turns[-self.max_turns :]
        payload["active_turn"] = None
        payload["next_turn_id"] = max(
            self._coerce_int(payload.get("next_turn_id"), 1),
            turn.turn_id + 1,
        )
        self._next_turn_id = int(payload["next_turn_id"])
        self._save_payload(payload)

    def query(
        self,
        *,
        tool_name: Optional[str] = None,
        keyword: Optional[str] = None,
        limit_turns: Optional[int] = None,
        include_empty_turns: bool = False,
    ) -> Dict[str, Any]:
        """Search the finished turns for matching tool calls.

        Args:
            tool_name: Keep only calls whose stripped ``tool_name`` equals
                this value (after stripping); empty/None disables the filter.
            keyword: Keep only calls whose JSON serialization contains this
                text, compared case-insensitively; empty/None disables it.
            limit_turns: Inspect only the newest N finished turns. ``None``,
                zero, negative or non-numeric values mean "no limit".
            include_empty_turns: Also report turns that have no tool calls,
                but only when neither filter is active.

        Returns:
            A dict with ``success``, ``max_turns``, ``stored_turns``,
            ``matched_turns``, ``has_active_turn`` and ``turns`` (each turn
            carrying only its matching ``tool_calls``).
        """
        payload = self._load_payload()
        turns = payload.get("turns", [])
        if not isinstance(turns, list):
            turns = []
        name_filter = (tool_name or "").strip()
        keyword_filter = (keyword or "").strip().lower()
        # A negative limit used to slice from the front (turns[k:]) and drop
        # the oldest turns; any non-positive limit now means "all turns".
        limit = self._coerce_int(limit_turns, 0)
        selected_turns = turns[-limit:] if limit > 0 else list(turns)
        results: List[Dict[str, Any]] = []
        for turn in selected_turns:
            if not isinstance(turn, dict):
                continue
            raw_calls = turn.get("tool_calls", [])
            matched_calls: List[Dict[str, Any]] = []
            if isinstance(raw_calls, list):
                for item in raw_calls:
                    if not isinstance(item, dict):
                        continue
                    if name_filter and str(item.get("tool_name", "")).strip() != name_filter:
                        continue
                    # Serialize only when a keyword search actually needs it.
                    if keyword_filter and keyword_filter not in json.dumps(item, ensure_ascii=False).lower():
                        continue
                    matched_calls.append(item)
            has_no_calls = not isinstance(raw_calls, list) or not raw_calls
            unfiltered = not name_filter and not keyword_filter
            if matched_calls or (include_empty_turns and has_no_calls and unfiltered):
                results.append(
                    {
                        "turn_id": self._coerce_int(turn.get("turn_id"), 0),
                        "started_at": str(turn.get("started_at", "")),
                        "completed_at": str(turn.get("completed_at", "")),
                        "user_message": str(turn.get("user_message", "")),
                        "assistant_final": str(turn.get("assistant_final", "")),
                        "tool_calls": matched_calls,
                    }
                )
        return {
            "success": True,
            "max_turns": self.max_turns,
            "stored_turns": len(turns),
            "matched_turns": len(results),
            "has_active_turn": isinstance(payload.get("active_turn"), dict),
            "turns": results,
        }

    def _bootstrap(self) -> None:
        """Normalize the file on startup and restore the turn-id counter."""
        payload = self._load_payload()
        turns = payload.get("turns", [])
        if not isinstance(turns, list):
            turns = []
        payload["turns"] = turns[-self.max_turns :]

        # Highest id already used by a retained or in-flight turn.
        highest_seen = 0
        for item in (*payload["turns"], payload.get("active_turn")):
            if isinstance(item, dict):
                highest_seen = max(highest_seen, self._coerce_int(item.get("turn_id"), 0))

        stored = payload.get("next_turn_id")
        stored_is_valid = isinstance(stored, int) and not isinstance(stored, bool) and stored > 0
        stored_next = stored if stored_is_valid else 1
        # _load_payload always fills in a default counter, so the stored
        # value alone cannot be trusted: never hand out an id that an
        # existing turn already uses.
        self._next_turn_id = max(stored_next, highest_seen + 1)
        payload["next_turn_id"] = self._next_turn_id
        self._save_payload(payload)

    def _load_payload(self) -> Dict[str, Any]:
        """Read the trace file, falling back to an empty payload.

        A missing file, an unreadable file, invalid UTF-8, invalid JSON or a
        non-object top level all yield a fresh empty payload.
        """
        if not self.path.is_file():
            return self._empty_payload()
        try:
            raw = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised for bytes that are not valid UTF-8.
            return self._empty_payload()
        if not isinstance(raw, dict):
            return self._empty_payload()
        raw.setdefault("active_turn", None)
        raw.setdefault("turns", [])
        raw.setdefault("next_turn_id", self._next_turn_id)
        raw["max_turns"] = self.max_turns
        return raw

    def _save_payload(self, payload: Dict[str, Any]) -> None:
        """Write ``payload`` to the trace file, creating parent directories.

        The JSON is written to a sibling ``<name>.tmp`` file which is then
        renamed over the target, so an interrupted write does not leave a
        truncated trace file behind.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(payload, ensure_ascii=False, indent=2)
        staging = self.path.with_name(self.path.name + ".tmp")
        staging.write_text(serialized, encoding="utf-8")
        staging.replace(self.path)

    def _trim_text(self, text: str, limit: int) -> str:
        """Return ``text`` as a string, truncated to ``limit`` characters.

        Falsy input becomes ``""``. Truncated text gets a
        ``"\\n...[truncated]"`` marker appended after the kept prefix.
        """
        value = str(text or "")
        if len(value) <= limit:
            return value
        return value[:limit] + "\n...[truncated]"

    def _empty_payload(self) -> Dict[str, Any]:
        """Build the payload used when no usable file content exists."""
        return {
            "max_turns": self.max_turns,
            "next_turn_id": self._next_turn_id,
            "active_turn": None,
            "turns": [],
        }

    def _matching_active_turn(self, payload: Dict[str, Any], turn: ToolTraceTurn) -> Optional[Dict[str, Any]]:
        """Return the payload's active turn if it belongs to ``turn``, else None."""
        active_turn = payload.get("active_turn")
        if not isinstance(active_turn, dict):
            return None
        if self._coerce_int(active_turn.get("turn_id"), -1) != turn.turn_id:
            return None
        return active_turn

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Convert ``value`` with ``int()``; return ``default`` if that fails."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default