226 lines
8.2 KiB
Python
226 lines
8.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
from copy import deepcopy
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from datetime import datetime
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Dict, List, Optional
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(slots=True)
|
||
|
|
class ToolTraceEntry:
|
||
|
|
tool_name: str
|
||
|
|
arguments: Dict[str, Any]
|
||
|
|
thought: str = ""
|
||
|
|
result: str = ""
|
||
|
|
tool_call_id: str = ""
|
||
|
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||
|
|
|
||
|
|
def to_dict(self) -> Dict[str, Any]:
|
||
|
|
return {
|
||
|
|
"tool_name": self.tool_name,
|
||
|
|
"arguments": deepcopy(self.arguments),
|
||
|
|
"thought": self.thought,
|
||
|
|
"result": self.result,
|
||
|
|
"tool_call_id": self.tool_call_id,
|
||
|
|
"created_at": self.created_at,
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(slots=True)
|
||
|
|
class ToolTraceTurn:
|
||
|
|
turn_id: int
|
||
|
|
started_at: str
|
||
|
|
user_message: str
|
||
|
|
|
||
|
|
|
||
|
|
class ToolTraceStore:
    """JSON-file-backed log of conversation turns and their tool calls.

    The file holds a dictionary with four keys:

    * ``turns`` - completed turns, trimmed to the most recent ``max_turns``;
    * ``active_turn`` - the turn currently in progress, or ``None``;
    * ``next_turn_id`` - the id the next call to :meth:`begin_turn` will use;
    * ``max_turns`` - the limit configured on this instance.

    Every public operation re-reads the file, applies its change and writes
    the file back. A missing, unreadable or malformed file is treated as an
    empty store rather than an error.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        max_turns: int = 20,
        max_result_chars: int = 24000,
        max_thought_chars: int = 12000,
    ) -> None:
        """Open (or create) the store at *path*.

        Args:
            path: Location of the JSON file; parent directories are created
                on first save.
            max_turns: Number of completed turns to retain (minimum 1).
            max_result_chars: Character cap applied to tool results and to
                the final assistant text (minimum 1000).
            max_thought_chars: Character cap applied to thoughts
                (minimum 1000).
        """
        self.path = Path(path)
        self.max_turns = max(1, int(max_turns))
        self.max_result_chars = max(1000, int(max_result_chars))
        self.max_thought_chars = max(1000, int(max_thought_chars))
        self._next_turn_id = 1
        self._bootstrap()

    def begin_turn(self, user_message: str) -> ToolTraceTurn:
        """Start a new turn and persist it as the active turn.

        Any previously active turn that was never finished is overwritten.

        Returns:
            A handle to pass to :meth:`record_tool_call` and
            :meth:`finish_turn`.
        """
        turn = ToolTraceTurn(
            turn_id=self._next_turn_id,
            started_at=datetime.now().isoformat(),
            user_message=user_message,
        )
        payload = self._load_payload()
        payload["active_turn"] = {
            "turn_id": turn.turn_id,
            "started_at": turn.started_at,
            "completed_at": "",
            "user_message": turn.user_message,
            "assistant_final": "",
            "tool_calls": [],
        }
        payload["next_turn_id"] = turn.turn_id + 1
        self._next_turn_id = turn.turn_id + 1
        self._save_payload(payload)
        return turn

    def record_tool_call(
        self,
        turn: ToolTraceTurn,
        *,
        tool_name: str,
        arguments: Dict[str, Any],
        thought: str = "",
        result: str = "",
        tool_call_id: str = "",
    ) -> None:
        """Append one tool call to the active turn.

        The call is ignored when *turn* is not the turn currently stored as
        active. ``thought`` and ``result`` are truncated to the configured
        character limits before being saved.
        """
        payload = self._load_payload()
        active_turn = self._matching_active_turn(payload, turn)
        if active_turn is None:
            return

        entry = ToolTraceEntry(
            tool_name=tool_name,
            arguments=deepcopy(arguments),
            thought=self._trim_text(thought, self.max_thought_chars),
            result=self._trim_text(result, self.max_result_chars),
            tool_call_id=tool_call_id,
        )
        tool_calls = active_turn.get("tool_calls")
        if not isinstance(tool_calls, list):
            # A malformed persisted value must not cause the new entry to be
            # silently dropped; start a fresh list instead.
            tool_calls = []
            active_turn["tool_calls"] = tool_calls
        tool_calls.append(entry.to_dict())
        payload["active_turn"] = active_turn
        self._save_payload(payload)

    def finish_turn(self, turn: ToolTraceTurn, assistant_final: str) -> None:
        """Complete the active turn and move it into the stored history.

        The call is ignored when *turn* is not the turn currently stored as
        active. The history is trimmed to the newest ``max_turns`` entries.
        """
        payload = self._load_payload()
        active_turn = self._matching_active_turn(payload, turn)
        if active_turn is None:
            return

        active_turn["assistant_final"] = self._trim_text(assistant_final, self.max_result_chars)
        active_turn["completed_at"] = datetime.now().isoformat()

        turns = payload.get("turns", [])
        if not isinstance(turns, list):
            turns = []
        turns.append(active_turn)
        payload["turns"] = turns[-self.max_turns :]
        payload["active_turn"] = None
        payload["next_turn_id"] = max(self._as_int(payload.get("next_turn_id"), 1), turn.turn_id + 1)
        self._next_turn_id = int(payload["next_turn_id"])
        self._save_payload(payload)

    def query(
        self,
        *,
        tool_name: Optional[str] = None,
        keyword: Optional[str] = None,
        limit_turns: Optional[int] = None,
        include_empty_turns: bool = False,
    ) -> Dict[str, Any]:
        """Search completed turns for matching tool calls.

        Args:
            tool_name: When given, keep only calls whose ``tool_name`` equals
                this value exactly (surrounding whitespace ignored).
            keyword: When given, keep only calls whose JSON serialisation
                contains this text, compared case-insensitively.
            limit_turns: When a positive number, only the newest
                ``limit_turns`` stored turns are searched. ``None``, zero and
                negative values search every stored turn.
            include_empty_turns: Also return turns that have no tool calls,
                but only when neither filter is active.

        Returns:
            A dictionary with ``success``, ``max_turns``, ``stored_turns``,
            ``matched_turns``, ``has_active_turn`` and ``turns`` (each turn
            carrying only its matching ``tool_calls``).
        """
        payload = self._load_payload()
        turns = payload.get("turns", [])
        if not isinstance(turns, list):
            turns = []

        name_filter = (tool_name or "").strip()
        keyword_filter = (keyword or "").strip().lower()
        # A negative limit used to slice ``turns[k:]`` and drop the *oldest*
        # turns; treat every non-positive limit as "no limit", as 0 already was.
        if limit_turns is not None and limit_turns > 0:
            selected_turns = turns[-limit_turns:]
        else:
            selected_turns = list(turns)

        results: List[Dict[str, Any]] = []
        for turn in selected_turns:
            if not isinstance(turn, dict):
                continue
            raw_calls = turn.get("tool_calls", [])
            matched_calls: List[Dict[str, Any]] = []
            if isinstance(raw_calls, list):
                for item in raw_calls:
                    if not isinstance(item, dict):
                        continue
                    if name_filter and str(item.get("tool_name", "")).strip() != name_filter:
                        continue
                    # Serialise only when a keyword actually has to be searched.
                    if keyword_filter and keyword_filter not in json.dumps(item, ensure_ascii=False).lower():
                        continue
                    matched_calls.append(item)

            has_no_calls = not isinstance(raw_calls, list) or not raw_calls
            unfiltered = not name_filter and not keyword_filter
            if matched_calls or (include_empty_turns and has_no_calls and unfiltered):
                results.append(
                    {
                        "turn_id": self._as_int(turn.get("turn_id"), 0),
                        "started_at": str(turn.get("started_at", "")),
                        "completed_at": str(turn.get("completed_at", "")),
                        "user_message": str(turn.get("user_message", "")),
                        "assistant_final": str(turn.get("assistant_final", "")),
                        "tool_calls": matched_calls,
                    }
                )

        return {
            "success": True,
            "max_turns": self.max_turns,
            "stored_turns": len(turns),
            "matched_turns": len(results),
            "has_active_turn": isinstance(payload.get("active_turn"), dict),
            "turns": results,
        }

    def _bootstrap(self) -> None:
        """Normalise the file on start-up and derive the next turn id."""
        payload = self._load_payload()
        turns = payload.get("turns", [])
        if not isinstance(turns, list):
            turns = []
        payload["turns"] = turns[-self.max_turns :]

        known_ids = [self._as_int(item.get("turn_id"), 0) for item in payload["turns"] if isinstance(item, dict)]
        active_turn = payload.get("active_turn")
        if isinstance(active_turn, dict):
            known_ids.append(self._as_int(active_turn.get("turn_id"), 0))
        highest_seen = max(known_ids, default=0)

        # ``_load_payload`` fills in a missing ``next_turn_id`` with 1, so the
        # stored counter alone cannot be trusted: always reconcile it with the
        # ids actually present, otherwise ids would be handed out twice.
        stored_next = self._as_int(payload.get("next_turn_id"), 0)
        self._next_turn_id = max(1, stored_next, highest_seen + 1)
        payload["next_turn_id"] = self._next_turn_id

        self._save_payload(payload)

    def _load_payload(self) -> Dict[str, Any]:
        """Read the file, falling back to an empty payload when unusable."""
        if not self.path.is_file():
            return self._empty_payload()
        try:
            raw = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised for files that are not valid UTF-8.
            return self._empty_payload()
        if not isinstance(raw, dict):
            return self._empty_payload()
        raw.setdefault("active_turn", None)
        raw.setdefault("turns", [])
        raw.setdefault("next_turn_id", self._next_turn_id)
        raw["max_turns"] = self.max_turns
        return raw

    def _save_payload(self, payload: Dict[str, Any]) -> None:
        """Write *payload* to the file, replacing the previous contents.

        The text goes to a sibling temporary file that is then renamed over
        the target, so a failure part-way through cannot leave a truncated
        file (which the loader would silently treat as an empty store).
        """
        text = json.dumps(payload, ensure_ascii=False, indent=2)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        try:
            tmp_path.write_text(text, encoding="utf-8")
            tmp_path.replace(self.path)
        except BaseException:
            # Do not leave a stray temporary file behind; re-raise unchanged.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _trim_text(self, text: str, limit: int) -> str:
        """Return *text* as a string cut to *limit* characters.

        Falsy input becomes ``""``. When truncation happens a
        ``"\\n...[truncated]"`` marker is appended after the kept prefix.
        """
        value = str(text or "")
        if len(value) <= limit:
            return value
        return value[:limit] + "\n...[truncated]"

    def _empty_payload(self) -> Dict[str, Any]:
        """Build the payload used when no usable file exists."""
        return {
            "max_turns": self.max_turns,
            "next_turn_id": self._next_turn_id,
            "active_turn": None,
            "turns": [],
        }

    def _matching_active_turn(self, payload: Dict[str, Any], turn: ToolTraceTurn) -> Optional[Dict[str, Any]]:
        """Return the stored active turn if it is *turn*, otherwise ``None``."""
        active_turn = payload.get("active_turn")
        if not isinstance(active_turn, dict):
            return None
        if self._as_int(active_turn.get("turn_id"), -1) != turn.turn_id:
            return None
        return active_turn

    @staticmethod
    def _as_int(value: Any, default: int = 0) -> int:
        """Coerce a persisted value to ``int``; *default* if missing or malformed."""
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default
|