meeting_memory/providers/rule_based.py

104 lines
4.1 KiB
Python
Raw Normal View History

2026-06-24 07:05:19 +00:00
from __future__ import annotations
import json
import re
import uuid
from typing import Any, Dict, List
from .base import AgentProvider, AssistantTurn, ToolCall
class RuleBasedMeetingProvider(AgentProvider):
    """Offline, keyword-driven provider used for demos.

    It drives the same tool-call loop an LLM-backed provider would. Each call
    either requests exactly one tool, or, once a tool result is the newest
    message, turns that result into a readable reply.
    """

    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Choose the next assistant turn for the conversation in *messages*.

        Args:
            messages: Chat history. The final entry decides whether we are
                summarizing a tool result or reacting to the user.
            tools: Tool schemas offered by the caller. Not consulted here; the
                tool names below are fixed.

        Returns:
            An ``AssistantTurn`` with ``content`` (a tool-result summary) or
            with ``tool_calls`` holding a single ``ToolCall``.
        """
        newest = messages[-1]
        if newest.get("role") == "tool":
            summary = _summarize_tool_result(newest.get("name", ""), newest.get("content", "{}"))
            return AssistantTurn(content=summary)

        def request(tool_name: str, arguments: Dict[str, Any]) -> AssistantTurn:
            # Every request gets a fresh, random call id.
            call = ToolCall(id=f"call_{uuid.uuid4().hex}", name=tool_name, arguments=arguments)
            return AssistantTurn(tool_calls=[call])

        text = _last_user_message(messages)
        kb_id = _extract_knowledge_base_id(text)
        file_path = _extract_path(text)

        if _looks_like_list_bases(text):
            return request("list_knowledge_bases", {})

        if _looks_like_import(text) and file_path:
            # Optional arguments are only sent when a value was found.
            optional = {"knowledge_base_id": kb_id, "meeting_date": _extract_date(text)}
            import_args: Dict[str, Any] = {
                "file_path": file_path,
                **{key: value for key, value in optional.items() if value},
            }
            return request("import_meeting_transcript", import_args)

        # Anything else is treated as a question against the knowledge base.
        query_args: Dict[str, Any] = {
            "query": text,
            "limit": 8,
            **({"knowledge_base_id": kb_id} if kb_id else {}),
        }
        return request("query_knowledge", query_args)
def _summarize_tool_result(tool_name: str, raw: str) -> str:
    """Render a tool's JSON result as a short, human-readable reply.

    Args:
        tool_name: Name of the tool that produced *raw*; selects the template.
        raw: The tool's result text, expected to be JSON.

    Returns:
        A display string. Input that is not valid JSON is returned verbatim.
        JSON that is not an object, or that comes from an unrecognized tool,
        is pretty-printed.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    if not isinstance(data, dict):
        # Every template below uses dict access; a list or scalar payload
        # would raise AttributeError, so pretty-print it instead.
        return json.dumps(data, ensure_ascii=False, indent=2)
    if tool_name == "import_meeting_transcript":
        counts = data.get("counts")
        if not isinstance(counts, dict):
            # Tolerate a missing or null "counts" field: report zeros.
            counts = {}
        return (
            "已导入会议并更新知识库周会台账。\n"
            f"- knowledge_base_id: {data.get('knowledge_base_id') or data.get('team_id')}\n"
            f"- meeting_id: {data.get('meeting_id')}\n"
            f"- ledger: {data.get('ledger_path')}\n"
            # The trailing "," separates the risks count from the next item;
            # without it the two adjacent literals run together.
            f"- 抽取:指标 {counts.get('metrics', 0)},待办 {counts.get('actions', 0)},风险 {counts.get('risks', 0)},"
            f"完成事项 {counts.get('completed', 0)},追踪项 {counts.get('followups', 0)}"
        )
    if tool_name == "query_knowledge":
        answer = data.get("answer")
        if isinstance(answer, str):
            return answer
        # No usable text answer (missing, null, or non-text): show the raw payload.
        return json.dumps(data, ensure_ascii=False, indent=2)
    if tool_name == "list_knowledge_bases":
        bases = data.get("knowledge_bases", [])
        # Separate the names so they do not run together; str() tolerates
        # non-string entries.
        return "当前已有知识库:" + (",".join(map(str, bases)) if bases else "暂无")
    return json.dumps(data, ensure_ascii=False, indent=2)
def _last_user_message(messages: List[Dict[str, Any]]) -> str:
    """Return the content of the most recent ``"user"`` message as text.

    Yields an empty string when no message has the ``"user"`` role or the
    matching message has no ``"content"`` key.
    """
    latest = next((msg for msg in reversed(messages) if msg.get("role") == "user"), None)
    return "" if latest is None else str(latest.get("content", ""))
def _looks_like_import(text: str) -> bool:
    """Return True when *text* reads like a request to import a transcript file.

    Two conditions must both hold: one of the import keywords appears, and a
    transcript path can be extracted from the text.
    """
    if not any(keyword in text for keyword in ("导入", "读取", "更新", "import")):
        return False
    return bool(_extract_path(text))
def _extract_path(text: str) -> str:
quoted = re.search(r"['\"]([^'\"]+\.(?:txt|md))['\"]", text)
if quoted:
return quoted.group(1)
match = re.search(r"([A-Za-z]:\\[^\s]+?\.(?:txt|md))", text)
return match.group(1) if match else ""
def _extract_knowledge_base_id(text: str) -> str | None:
match = re.search(r"(?:knowledge_base_id|知识库|库|team_id|团队|team)[:= ]*([A-Za-z0-9_\-\u4e00-\u9fff]+)", text)
return match.group(1) if match else None
def _extract_date(text: str) -> str:
    """Return the first ``YYYY-MM-DD``-shaped substring of *text*, or ``""``.

    Only the digit layout is checked; the value is not validated as a real
    calendar date. In Python 3 ``\\d`` also matches non-ASCII decimal digits.
    """
    first = next(re.finditer(r"\d{4}-\d{2}-\d{2}", text), None)
    return "" if first is None else first.group(0)
def _looks_like_list_bases(text: str) -> bool:
    """Return True when *text* contains a phrase asking which knowledge bases exist."""
    for phrase in ("有哪些知识库", "知识库列表", "列出知识库", "可用知识库", "有哪些库"):
        if phrase in text:
            return True
    return False