meeting_memory/providers/rule_based.py

69 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import annotations
import json
import re
import uuid
from typing import Any, Dict, List
from .base import AgentProvider, AssistantTurn, ToolCall
class RuleBasedMeetingProvider(AgentProvider):
    """Offline provider for demos.

    It still drives the same tool-call loop as an LLM provider. Instead of
    asking a model, it picks the tool with simple keyword/path rules.
    """

    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Decide the next assistant turn from the conversation so far.

        Args:
            messages: Chat history as role/content dicts. Only the last message
                and the most recent ``"user"`` message are inspected.
            tools: Tool schemas offered by the caller. Unused here, because the
                two tool names below are hard-coded.

        Returns:
            - An empty text reply if ``messages`` is empty.
            - A final text reply summarizing the result if the last message is
              a tool result.
            - A ``store_meeting_memory`` call with that ``file_path`` if the
              latest user message looks like an import request and names a
              .txt/.md path.
            - Otherwise, a ``query_meeting_memory`` call with the user's text
              and ``top_k=5``.
        """
        if not messages:
            # An empty conversation has no turn to answer.
            return AssistantTurn(content="")

        last = messages[-1]
        if last.get("role") == "tool":
            content = last.get("content")
            if content is None:
                # Treat a null tool payload like a missing one: an empty JSON object.
                content = "{}"
            return AssistantTurn(content=_summarize_tool_result(last.get("name", ""), content))

        user_message = _last_user_message(messages)
        path = _extract_path(user_message)
        if _looks_like_import(user_message) and path:
            tool_name, arguments = "store_meeting_memory", {"file_path": path}
        else:
            tool_name, arguments = "query_meeting_memory", {"query": user_message, "top_k": 5}
        return AssistantTurn(
            tool_calls=[ToolCall(id=f"call_{uuid.uuid4().hex}", name=tool_name, arguments=arguments)]
        )
def _summarize_tool_result(tool_name: str, raw: str) -> str:
    """Turn a tool's raw result into the assistant's final reply text.

    Args:
        tool_name: Name of the tool that produced the result. Selects the format.
        raw: Tool message content, expected to be a JSON document.

    Returns:
        - ``""`` if ``raw`` is ``None``.
        - ``raw`` unchanged if it is not valid JSON (``str(raw)`` for
          non-text values).
        - A short status summary for ``store_meeting_memory``.
        - The ``answer`` field for ``query_meeting_memory``. A non-string
          answer is pretty-printed as JSON.
        - Otherwise the pretty-printed JSON. This also covers results that are
          not a JSON object and query results without an answer.
    """
    if raw is None:
        return ""
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        # TypeError: json.loads only accepts str/bytes/bytearray input.
        return raw if isinstance(raw, str) else str(raw)

    if not isinstance(data, dict):
        # A JSON array/scalar has no named fields to pick out.
        return json.dumps(data, ensure_ascii=False, indent=2)

    if tool_name == "store_meeting_memory":
        # The header line reads: "Attempted to write to the meeting long-term memory."
        return (
            "已尝试写入会议长期记忆。\n"
            f"- stored: {data.get('stored')}\n"
            f"- archive: {data.get('archive_path')}\n"
            f"- graph_enabled: {data.get('graph_enabled')}"
        )
    if tool_name == "query_meeting_memory":
        answer = data.get("answer")
        if isinstance(answer, str):
            return answer
        if answer is not None:
            return json.dumps(answer, ensure_ascii=False, indent=2)
        # No usable answer: fall through and show the whole payload
        # rather than returning None.
    return json.dumps(data, ensure_ascii=False, indent=2)
def _last_user_message(messages: List[Dict[str, Any]]) -> str:
    """Return the content of the most recent ``"user"`` message as text.

    Args:
        messages: Chat history as role/content dicts.

    Returns:
        The latest user message's content converted with ``str``. Returns
        ``""`` when there is no user message or its content is missing/``None``.
    """
    for message in reversed(messages):
        if message.get("role") == "user":
            content = message.get("content")
            # A null content must not turn into the literal query text "None".
            return "" if content is None else str(content)
    return ""
def _looks_like_import(text: str) -> bool:
    """Return True if *text* contains an ingest keyword and a .txt/.md path.

    The Chinese keywords mean import / read / save / record / archive.
    Keyword matching is a case-insensitive substring test. The path must be
    one that ``_extract_path`` can find.
    """
    # casefold() lets "Import"/"IMPORT" match; the Chinese keywords are unaffected.
    lowered = text.casefold()
    has_keyword = any(word in lowered for word in ("导入", "读取", "保存", "记录", "归档", "import"))
    return has_keyword and bool(_extract_path(text))
# Quoted path: ASCII quotes plus the curly and corner-bracket quotes common in
# Chinese text. The captured path may contain spaces.
_QUOTED_PATH_RE = re.compile(
    r"['\"\u201c\u2018\u300c\u300e]"
    r"([^'\"\u201c\u201d\u2018\u2019\u300c\u300d\u300e\u300f]+\.(?:txt|md))"
    r"['\"\u201d\u2019\u300d\u300f]",
    re.IGNORECASE,
)
# Unquoted Windows path such as C:\dir\notes.txt; it stops at whitespace. The
# lookahead rejects longer extensions (".mdx", ".txt2"). CJK text or
# punctuation may still follow the path directly.
_WINDOWS_PATH_RE = re.compile(
    r"([A-Za-z]:\\[^\s]+?\.(?:txt|md))(?![A-Za-z0-9_])",
    re.IGNORECASE,
)


def _extract_path(text: str) -> str:
    """Return the first .txt/.md file path mentioned in *text*, or "".

    A quoted path is preferred, so paths with spaces work. Otherwise an
    unquoted Windows drive path is accepted. Extensions match
    case-insensitively.
    """
    for pattern in (_QUOTED_PATH_RE, _WINDOWS_PATH_RE):
        match = pattern.search(text)
        if match:
            return match.group(1)
    return ""