321 lines
11 KiB
Python
321 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import fnmatch
|
|
import os
|
|
import platform
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from .registry import ToolContext, ToolRegistry
|
|
|
|
|
|
def build_general_registry() -> ToolRegistry:
    """Build a ToolRegistry holding the general-purpose workspace tools.

    Registers, in this order: ``current_time``, ``list_directory``,
    ``search_files``, ``read_file``, ``write_file``, ``execute_shell``,
    ``run_python`` and ``tool_trace_query``. Each tool is registered with a
    name, a user-facing description, a JSON-schema object for its arguments
    (unknown keys are rejected through ``additionalProperties: False``) and
    the module-level handler function that implements it.

    Returns:
        The populated registry.
    """
    registry = ToolRegistry()

    def object_schema(properties: dict[str, Any], required: list[str] | None = None) -> dict[str, Any]:
        # Closed JSON-schema object; keys are inserted in the same order as
        # a hand-written schema: type, properties, [required], additionalProperties.
        schema: dict[str, Any] = {"type": "object", "properties": properties}
        if required is not None:
            schema["required"] = required
        schema["additionalProperties"] = False
        return schema

    # (name, description, parameters schema, handler) in registration order.
    tool_specs = (
        (
            "current_time",
            "获取当前本地时间。",
            object_schema({}),
            _current_time,
        ),
        (
            "list_directory",
            "列出工作区内某个目录下的文件和子目录。",
            object_schema(
                {
                    "path": {"type": "string", "description": "相对工作区的目录路径,默认当前工作区。"},
                }
            ),
            _list_directory,
        ),
        (
            "search_files",
            "在工作区内按文件名搜索文件,也可限定目录。",
            object_schema(
                {
                    "pattern": {"type": "string", "description": "文件名模式,例如 *.py、*ledger*。"},
                    "path": {"type": "string", "description": "起始目录,默认当前工作区。"},
                    "limit": {"type": "integer", "minimum": 1, "maximum": 200},
                },
                ["pattern"],
            ),
            _search_files,
        ),
        (
            "read_file",
            "读取工作区内 UTF-8 文本文件。",
            object_schema({"path": {"type": "string"}}, ["path"]),
            _read_file,
        ),
        (
            "write_file",
            "写入工作区内 UTF-8 文本文件。",
            object_schema(
                {
                    "path": {"type": "string"},
                    "content": {"type": "string"},
                },
                ["path", "content"],
            ),
            _write_file,
        ),
        (
            "execute_shell",
            "执行一条 shell 命令,用于环境检查、项目调试和只读类命令行操作。",
            object_schema(
                {
                    "command": {"type": "string"},
                    "timeout_seconds": {"type": "integer", "minimum": 1, "maximum": 120},
                },
                ["command"],
            ),
            _execute_shell,
        ),
        (
            "run_python",
            "运行一段简短 Python 代码,用于调试、数据检查或快速转换。",
            object_schema(
                {
                    "code": {"type": "string"},
                    "timeout_seconds": {"type": "integer", "minimum": 1, "maximum": 120},
                },
                ["code"],
            ),
            _run_python,
        ),
        (
            "tool_trace_query",
            "查询本会话已落盘的工具调用记录,用于按需回看长链路工具过程,而不是一直把它们塞进上下文。",
            object_schema(
                {
                    "tool_name": {"type": "string", "description": "按工具名精确过滤。"},
                    "keyword": {"type": "string", "description": "按参数、思考、结果中的关键词过滤。"},
                    "limit_turns": {"type": "integer", "minimum": 1, "maximum": 20},
                    "include_empty_turns": {"type": "boolean"},
                }
            ),
            _tool_trace_query,
        ),
    )

    for tool_name, tool_description, tool_parameters, tool_handler in tool_specs:
        registry.register(
            name=tool_name,
            description=tool_description,
            parameters=tool_parameters,
            handler=tool_handler,
        )
    return registry
|
|
|
|
|
|
def _current_time(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
|
|
return {"success": True, "current_time": datetime.now().isoformat()}
|
|
|
|
|
|
def _list_directory(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
    """List the files and sub-directories of a directory inside the workspace.

    Args:
        ctx: Tool context; ``ctx.workspace`` is the sandbox root.
        args: Tool arguments. ``path`` (optional) is the directory to list,
            relative to the workspace or absolute inside it; defaults to ".".

    Returns:
        ``{"success": True, "path": <absolute dir>, "entries": [...]}``.
        Each entry has ``name``, ``path`` (relative to the workspace),
        ``type`` ("directory" or "file"), ``size`` (bytes for regular files,
        otherwise ``None``) and ``modified_at`` (local ISO timestamp, second
        precision). Directories come first, then files, each ordered by
        case-insensitive name. A missing path or a non-directory yields
        ``{"success": False, "error": ...}`` instead.

    Raises:
        ValueError: if ``path`` resolves outside the workspace (raised by
            ``_safe_workspace_path``).
    """
    # _safe_workspace_path returns a *resolved* path, so relative_to() must be
    # taken against the resolved workspace too. Using ctx.workspace directly
    # raises ValueError when the workspace is relative or contains symlinks.
    workspace = ctx.workspace.resolve()
    path = _safe_workspace_path(ctx, str(args.get("path") or "."))
    if not path.exists():
        return {"success": False, "error": f"Path does not exist: {path}"}
    if not path.is_dir():
        return {"success": False, "error": f"Path is not a directory: {path}"}

    items: list[dict[str, Any]] = []
    for item in path.iterdir():
        try:
            stat = item.stat()
        except OSError:
            # stat() follows symlinks and fails on a dangling one; describe
            # the link itself instead of failing the whole listing.
            try:
                stat = item.lstat()
            except FileNotFoundError:
                continue  # entry disappeared after iterdir() listed it
        is_dir = item.is_dir()
        items.append(
            {
                "name": item.name,
                "path": str(item.relative_to(workspace)),
                "type": "directory" if is_dir else "file",
                "size": stat.st_size if item.is_file() else None,
                "modified_at": datetime.fromtimestamp(stat.st_mtime).isoformat(timespec="seconds"),
            }
        )
    # Directories first, then case-insensitive name. The sort is stable, so
    # ties keep iterdir() order.
    items.sort(key=lambda entry: (entry["type"] != "directory", entry["name"].lower()))
    return {"success": True, "path": str(path), "entries": items}
|
|
|
|
|
|
def _search_files(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
    """Recursively search a workspace directory for names matching a glob.

    Args:
        ctx: Tool context; ``ctx.workspace`` is the sandbox root.
        args: ``pattern`` (required) is an ``fnmatch`` pattern matched
            against each entry's name (not its path). ``path`` (optional,
            default ".") is the starting directory. ``limit`` (optional,
            default 50) is clamped to the schema's 1..200 range.

    Returns:
        ``{"success": True, "base_path", "pattern", "matches", "truncated"}``.
        Each match has ``name``, ``path`` (relative to the workspace) and
        ``type``. ``truncated`` is True only when at least one further match
        existed beyond ``limit``. A missing path or a non-directory yields
        ``{"success": False, "error": ...}``.

    Raises:
        ValueError: if ``path`` resolves outside the workspace (raised by
            ``_safe_workspace_path``).
    """
    # base is resolved by _safe_workspace_path, so relative_to() must use the
    # resolved workspace as well. Otherwise it raises for symlinked or
    # relative workspaces.
    workspace = ctx.workspace.resolve()
    base = _safe_workspace_path(ctx, str(args.get("path") or "."))
    if not base.exists():
        return {"success": False, "error": f"Path does not exist: {base}"}
    if not base.is_dir():
        return {"success": False, "error": f"Path is not a directory: {base}"}

    pattern = str(args["pattern"])
    # Enforce the schema bounds even if the caller skipped validation.
    # A negative limit would otherwise stop the search before any match.
    limit = max(1, min(200, int(args.get("limit") or 50)))
    matches: list[dict[str, Any]] = []
    truncated = False
    for item in base.rglob("*"):
        if not fnmatch.fnmatch(item.name, pattern):
            continue
        if len(matches) >= limit:
            # A match beyond the limit exists, so the result really is cut short.
            truncated = True
            break
        matches.append(
            {
                "name": item.name,
                "path": str(item.relative_to(workspace)),
                "type": "directory" if item.is_dir() else "file",
            }
        )
    return {
        "success": True,
        "base_path": str(base),
        "pattern": pattern,
        "matches": matches,
        "truncated": truncated,
    }
|
|
|
|
|
|
def _read_file(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
    """Read a UTF-8 text file inside the workspace.

    Args:
        ctx: Tool context; ``ctx.workspace`` is the sandbox root.
        args: ``path`` (required), relative to the workspace or absolute
            inside it.

    Returns:
        ``{"success": True, "path": <absolute path>, "content": <text>}``.
        A missing path, a non-file, or content that is not valid UTF-8
        yields ``{"success": False, "error": ...}``. Undecodable files are
        reported like the other failure cases rather than raising.

    Raises:
        ValueError: if ``path`` resolves outside the workspace (raised by
            ``_safe_workspace_path``).
    """
    path = _safe_workspace_path(ctx, str(args["path"]))
    if not path.exists():
        return {"success": False, "error": f"File does not exist: {path}"}
    if not path.is_file():
        return {"success": False, "error": f"Path is not a file: {path}"}
    try:
        content = path.read_text(encoding="utf-8")
    except UnicodeDecodeError as exc:
        return {
            "success": False,
            "error": f"File is not valid UTF-8 text: {path} (invalid byte at offset {exc.start})",
        }
    return {"success": True, "path": str(path), "content": content}
|
|
|
|
|
|
def _write_file(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
    """Write text content to a file inside the workspace as UTF-8.

    Missing parent directories are created. As with ``Path.write_text``,
    each "\\n" in ``content`` is written as ``os.linesep``.

    Args:
        ctx: Tool context; ``ctx.workspace`` is the sandbox root.
        args: ``path`` (required) and ``content`` (required).

    Returns:
        ``{"success": True, "path": <absolute path>, "bytes_written": n}``,
        where ``n`` is the number of bytes actually written to disk. If the
        target is an existing directory, ``{"success": False, "error": ...}``.

    Raises:
        ValueError: if ``path`` resolves outside the workspace (raised by
            ``_safe_workspace_path``).
    """
    path = _safe_workspace_path(ctx, str(args["path"]))
    if path.is_dir():
        return {"success": False, "error": f"Path is a directory: {path}"}
    path.parent.mkdir(parents=True, exist_ok=True)
    content = str(args["content"])
    # Text-mode writes translate "\n" to os.linesep. Applying that translation
    # explicitly keeps the on-disk bytes identical to write_text(). It also
    # makes bytes_written exact; it previously under-counted on Windows.
    data = content.replace("\n", os.linesep).encode("utf-8")
    path.write_bytes(data)
    return {"success": True, "path": str(path), "bytes_written": len(data)}
|
|
|
|
|
|
def _execute_shell(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
    """Run one shell command in the workspace and capture its output.

    The command is passed through ``_normalize_shell_command`` and wrapped
    by ``_default_shell_command`` (PowerShell on Windows, ``bash -lc``
    elsewhere). It runs with:
    - the workspace as working directory;
    - stdin connected to ``DEVNULL``;
    - the environment from ``_subprocess_env``.
    Output is decoded as UTF-8 with replacement characters and trimmed with
    ``_trim_output``.

    Args:
        ctx: Tool context; ``ctx.workspace`` is the working directory.
        args: ``command`` (required) and optional ``timeout_seconds``
            (default 20, clamped to the schema's 1..120 range).

    Returns:
        A dict with ``success`` (exit code 0), ``command`` (as normalized),
        ``returncode``, ``stdout`` and ``stderr``. On timeout, or when the
        shell cannot be started, ``success`` is False, ``returncode`` is
        ``None`` and ``error`` explains why.
    """
    # The schema caps timeout_seconds at 1..120. Enforce it here too in case
    # the registry does not validate arguments.
    timeout = max(1, min(120, int(args.get("timeout_seconds") or 20)))
    command = _normalize_shell_command(str(args["command"]))
    try:
        proc = subprocess.run(
            _default_shell_command(command),
            cwd=str(ctx.workspace),
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
            stdin=subprocess.DEVNULL,
            env=_subprocess_env(),
            shell=False,
        )
    except subprocess.TimeoutExpired as exc:
        # Partial output on TimeoutExpired is bytes even with text=True
        # (per the subprocess docs), hence _coerce_subprocess_text.
        return {
            "success": False,
            "command": command,
            "returncode": None,
            "stdout": _trim_output(_coerce_subprocess_text(exc.stdout)),
            "stderr": _trim_output(_coerce_subprocess_text(exc.stderr)),
            "error": f"Command timed out after {timeout} seconds.",
        }
    except OSError as exc:
        # The shell could not be launched (e.g. bash/powershell missing, or
        # the command line exceeds the OS limit). Report it as a tool failure.
        return {
            "success": False,
            "command": command,
            "returncode": None,
            "stdout": "",
            "stderr": "",
            "error": f"Failed to start shell: {exc}",
        }
    return {
        "success": proc.returncode == 0,
        "command": command,
        "returncode": proc.returncode,
        "stdout": _trim_output(proc.stdout),
        "stderr": _trim_output(proc.stderr),
    }
|
|
|
|
|
|
def _run_python(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
    """Run a short Python snippet with the current interpreter.

    The snippet is executed as ``sys.executable -c <code>``. It runs with:
    - the workspace as working directory;
    - stdin connected to ``DEVNULL``;
    - the environment from ``_subprocess_env``.
    Output is decoded as UTF-8 with replacement characters and trimmed with
    ``_trim_output``.

    Args:
        ctx: Tool context; ``ctx.workspace`` is the working directory.
        args: ``code`` (required) and optional ``timeout_seconds``
            (default 20, clamped to the schema's 1..120 range).

    Returns:
        A dict with ``success`` (exit code 0), ``returncode``, ``stdout`` and
        ``stderr``. On timeout, or when the interpreter cannot be started,
        ``success`` is False, ``returncode`` is ``None`` and ``error`` says why.
    """
    # The schema caps timeout_seconds at 1..120. Enforce it here too in case
    # the registry does not validate arguments.
    timeout = max(1, min(120, int(args.get("timeout_seconds") or 20)))
    try:
        proc = subprocess.run(
            [sys.executable, "-c", str(args["code"])],
            cwd=str(ctx.workspace),
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
            stdin=subprocess.DEVNULL,
            env=_subprocess_env(),
            shell=False,
        )
    except subprocess.TimeoutExpired as exc:
        # Partial output on TimeoutExpired is bytes even with text=True
        # (per the subprocess docs), hence _coerce_subprocess_text.
        return {
            "success": False,
            "returncode": None,
            "stdout": _trim_output(_coerce_subprocess_text(exc.stdout)),
            "stderr": _trim_output(_coerce_subprocess_text(exc.stderr)),
            "error": f"Python snippet timed out after {timeout} seconds.",
        }
    except OSError as exc:
        # The interpreter could not be launched. For example, passing the code
        # via -c can exceed the Windows command-line length limit.
        return {
            "success": False,
            "returncode": None,
            "stdout": "",
            "stderr": "",
            "error": f"Failed to start Python: {exc}",
        }
    return {
        "success": proc.returncode == 0,
        "returncode": proc.returncode,
        "stdout": _trim_output(proc.stdout),
        "stderr": _trim_output(proc.stderr),
    }
|
|
|
|
|
|
def _tool_trace_query(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]:
|
|
if ctx.tool_trace_store is None:
|
|
return {"success": False, "error": "tool trace store is not configured"}
|
|
return ctx.tool_trace_store.query(
|
|
tool_name=args.get("tool_name"),
|
|
keyword=args.get("keyword"),
|
|
limit_turns=args.get("limit_turns"),
|
|
include_empty_turns=bool(args.get("include_empty_turns", False)),
|
|
)
|
|
|
|
|
|
def _safe_workspace_path(ctx: ToolContext, value: str) -> Path:
|
|
workspace = ctx.workspace.resolve()
|
|
raw = Path(value)
|
|
candidate = raw.resolve() if raw.is_absolute() else (workspace / raw).resolve()
|
|
if candidate == workspace or workspace in candidate.parents:
|
|
return candidate
|
|
raise ValueError(f"Path escapes workspace: {value}")
|
|
|
|
|
|
def _default_shell_command(command: str) -> list[str]:
|
|
if platform.system().lower().startswith("win"):
|
|
script = (
|
|
"$ErrorActionPreference = 'Stop'; "
|
|
"[Console]::InputEncoding = [System.Text.UTF8Encoding]::new(); "
|
|
"[Console]::OutputEncoding = [System.Text.UTF8Encoding]::new(); "
|
|
"$OutputEncoding = [System.Text.UTF8Encoding]::new(); "
|
|
f"{command}; "
|
|
"if ($null -ne $LASTEXITCODE) { exit $LASTEXITCODE }"
|
|
)
|
|
return ["powershell", "-NoProfile", "-NonInteractive", "-Command", script]
|
|
return ["bash", "-lc", command]
|
|
|
|
|
|
def _normalize_shell_command(command: str) -> str:
|
|
if not platform.system().lower().startswith("win"):
|
|
return command
|
|
if "conda" in command and " run " in f" {command} " and "--no-capture-output" not in command:
|
|
return re.sub(r"\bconda\s+run\b", "conda run --no-capture-output", command, count=1)
|
|
return command
|
|
|
|
|
|
def _subprocess_env() -> dict[str, str]:
|
|
env = os.environ.copy()
|
|
env.setdefault("PYTHONIOENCODING", "utf-8")
|
|
env.setdefault("PYTHONUTF8", "1")
|
|
env.setdefault("CONDA_REPORT_ERRORS", "false")
|
|
return env
|
|
|
|
|
|
def _coerce_subprocess_text(value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
if isinstance(value, bytes):
|
|
return value.decode("utf-8", errors="replace")
|
|
return str(value)
|
|
|
|
|
|
def _trim_output(text: str, limit: int = 12000) -> str:
|
|
if len(text) <= limit:
|
|
return text
|
|
return text[:limit] + "\n...[truncated]"
|