69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
import uuid
|
||
from typing import Any, Dict, List
|
||
|
||
from .base import AgentProvider, AssistantTurn, ToolCall
|
||
|
||
|
||
class RuleBasedMeetingProvider(AgentProvider):
    """Offline provider for demos. It still uses the same tool-call loop as an LLM provider.

    Instead of calling a model, ``generate`` applies three fixed rules:

    1. The newest message is a tool result -> reply with a text summary of it.
    2. The latest user message looks like an import/save request and contains
       a ``.txt``/``.md`` path -> request ``store_meeting_memory`` for that path.
    3. Anything else -> request ``query_meeting_memory`` with the user message
       as the query (``top_k=5``).
    """

    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Return the next assistant turn for the conversation in ``messages``.

        Args:
            messages: Chat history as role-tagged dicts. Assumed oldest-first:
                ``messages[-1]`` is treated as the newest entry.
            tools: Not read by this provider; the tool names it emits are
                hard-coded. NOTE(review): presumably kept only for
                ``AgentProvider`` signature compatibility — confirm.

        Returns:
            A text-only turn when the newest message has role ``"tool"``;
            otherwise a turn carrying exactly one tool call with a fresh
            random ``call_<hex>`` id. An empty ``messages`` list is handled
            like an empty user query instead of raising ``IndexError``.
        """
        # Guard: indexing an empty history would raise IndexError.
        last = messages[-1] if messages else {}
        if last.get("role") == "tool":
            summary = _summarize_tool_result(last.get("name", ""), last.get("content", "{}"))
            return AssistantTurn(content=summary)

        user_message = _last_user_message(messages)
        path = _extract_path(user_message)
        if path and _looks_like_import(user_message):
            call = ToolCall(
                id=f"call_{uuid.uuid4().hex}",
                name="store_meeting_memory",
                arguments={"file_path": path},
            )
        else:
            call = ToolCall(
                id=f"call_{uuid.uuid4().hex}",
                name="query_meeting_memory",
                arguments={"query": user_message, "top_k": 5},
            )
        return AssistantTurn(tool_calls=[call])
|
||
|
||
|
||
def _summarize_tool_result(tool_name: str, raw: str) -> str:
|
||
try:
|
||
data = json.loads(raw)
|
||
except json.JSONDecodeError:
|
||
return raw
|
||
if tool_name == "store_meeting_memory":
|
||
return (
|
||
"已尝试写入会议长期记忆。\n"
|
||
f"- stored: {data.get('stored')}\n"
|
||
f"- archive: {data.get('archive_path')}\n"
|
||
f"- graph_enabled: {data.get('graph_enabled')}"
|
||
)
|
||
if tool_name == "query_meeting_memory":
|
||
return data.get("answer", json.dumps(data, ensure_ascii=False, indent=2))
|
||
return json.dumps(data, ensure_ascii=False, indent=2)
|
||
|
||
|
||
def _last_user_message(messages: List[Dict[str, Any]]) -> str:
|
||
for message in reversed(messages):
|
||
if message.get("role") == "user":
|
||
return str(message.get("content", ""))
|
||
return ""
|
||
|
||
|
||
def _looks_like_import(text: str) -> bool:
    """Heuristically decide whether ``text`` asks to ingest a meeting file.

    True only when both hold:
    * ``text`` contains one of the keywords 导入 (import), 读取 (read),
      保存 (save), 记录 (record), 归档 (archive) or the English "import",
      matched case-insensitively; and
    * ``_extract_path`` finds a file path in ``text``.
    """
    # Lower-case once so "Import"/"IMPORT" match; CJK keywords are unaffected.
    lowered = text.lower()
    has_keyword = any(word in lowered for word in ("导入", "读取", "保存", "记录", "归档", "import"))
    return has_keyword and bool(_extract_path(text))
|
||
|
||
|
||
def _extract_path(text: str) -> str:
|
||
quoted = re.search(r"['\"]([^'\"]+\.(?:txt|md))['\"]", text)
|
||
if quoted:
|
||
return quoted.group(1)
|
||
match = re.search(r"([A-Za-z]:\\[^\s,。;]+?\.(?:txt|md))", text)
|
||
return match.group(1) if match else ""
|