meeting_memory/providers/rule_based.py

104 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import annotations

import datetime
import json
import re
import uuid
from typing import Any, Dict, List

from .base import AgentProvider, AssistantTurn, ToolCall
class RuleBasedMeetingProvider(AgentProvider):
    """Offline provider for demos. It still uses the same tool-call loop as an LLM provider.

    Rather than calling a model, it routes on keyword/regex heuristics over the
    latest user message and requests exactly one tool per turn; once that tool
    has answered, it turns the tool's JSON result into the final reply.
    """

    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Return the next assistant turn for the conversation ``messages``.

        Args:
            messages: Chat history as dicts carrying ``role`` and ``content``
                (tool messages also carry ``name``). May be empty.
            tools: Tool schemas offered by the caller. Not consulted: routing is
                hard-coded to ``list_knowledge_bases``,
                ``import_meeting_transcript`` and ``query_knowledge``.

        Returns:
            If the last message is a tool result, a turn whose ``content`` is
            that result summarised. Otherwise a turn with one ``ToolCall``:
            ``list_knowledge_bases`` for "which knowledge bases" questions,
            ``import_meeting_transcript`` when an import keyword and a
            ``.txt``/``.md`` path are present, else ``query_knowledge`` with
            the whole user message as the query.
        """
        # A trailing tool message means our previous tool call has completed.
        if messages and messages[-1].get("role") == "tool":
            last = messages[-1]
            raw = last.get("content")
            # Treat a null payload like a missing one instead of crashing json.loads.
            if raw is None:
                raw = "{}"
            return AssistantTurn(content=_summarize_tool_result(last.get("name", ""), raw))

        user_message = _last_user_message(messages)
        knowledge_base_id = _extract_knowledge_base_id(user_message)
        path = _extract_path(user_message)

        if _looks_like_list_bases(user_message):
            return _single_tool_call_turn("list_knowledge_bases", {})

        if _looks_like_import(user_message) and path:
            args: dict[str, Any] = {"file_path": path}
            if knowledge_base_id:
                args["knowledge_base_id"] = knowledge_base_id
            meeting_date = _extract_date(user_message)
            if meeting_date:
                args["meeting_date"] = meeting_date
            return _single_tool_call_turn("import_meeting_transcript", args)

        # Fallback: treat the whole message as a question for the knowledge base.
        query_args: dict[str, Any] = {"query": user_message, "limit": 8}
        if knowledge_base_id:
            query_args["knowledge_base_id"] = knowledge_base_id
        return _single_tool_call_turn("query_knowledge", query_args)


def _single_tool_call_turn(name: str, arguments: Dict[str, Any]) -> AssistantTurn:
    """Build an assistant turn that requests one tool call under a fresh random call id."""
    return AssistantTurn(
        tool_calls=[ToolCall(id=f"call_{uuid.uuid4().hex}", name=name, arguments=arguments)]
    )
def _summarize_tool_result(tool_name: str, raw: str) -> str:
    """Turn a tool's raw JSON result into the user-facing reply text.

    Args:
        tool_name: Name of the tool that produced ``raw``.
        raw: The tool message content, normally a JSON document.

    Returns:
        ``raw`` itself when it is not valid JSON. For a JSON object, a
        tool-specific summary for ``import_meeting_transcript`` (ids, ledger
        path, extraction counts), ``query_knowledge`` (the ``answer`` field) or
        ``list_knowledge_bases`` (the base names). Anything else is returned
        as pretty-printed JSON.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    # The branches below read keys, so anything but a JSON object is shown verbatim.
    if not isinstance(data, dict):
        return json.dumps(data, ensure_ascii=False, indent=2)
    if tool_name == "import_meeting_transcript":
        # `or {}` also covers an explicit "counts": null.
        counts = data.get("counts") or {}
        return (
            "已导入会议并更新知识库周会台账。\n"
            f"- knowledge_base_id: {data.get('knowledge_base_id') or data.get('team_id')}\n"
            f"- meeting_id: {data.get('meeting_id')}\n"
            f"- ledger: {data.get('ledger_path')}\n"
            f"- 抽取:指标 {counts.get('metrics', 0)},待办 {counts.get('actions', 0)},风险 {counts.get('risks', 0)},"
            f"完成事项 {counts.get('completed', 0)},追踪项 {counts.get('followups', 0)}"
        )
    if tool_name == "query_knowledge":
        answer = data.get("answer")
        # A missing or null answer falls back to showing the whole payload.
        if answer is not None:
            return str(answer)
        return json.dumps(data, ensure_ascii=False, indent=2)
    if tool_name == "list_knowledge_bases":
        bases = data.get("knowledge_bases") or []
        # Separate names with the Chinese enumeration comma so they do not run together.
        names = "、".join(str(base) for base in bases)
        return "当前已有知识库:" + (names if bases else "暂无")
    return json.dumps(data, ensure_ascii=False, indent=2)
def _last_user_message(messages: List[Dict[str, Any]]) -> str:
    """Return the text of the most recent ``user`` message, or ``""`` if there is none.

    A missing or ``None`` content yields ``""`` rather than the string ``"None"``.
    """
    for message in reversed(messages):
        if message.get("role") == "user":
            content = message.get("content")
            return "" if content is None else str(content)
    return ""
def _looks_like_import(text: str) -> bool:
    """Return True if ``text`` asks to load a transcript file.

    Requires both an import-style keyword and a path that ``_extract_path``
    recognises. The keywords are 导入 ("import"), 读取 ("read"), 更新
    ("update"), and ``import`` in any letter case.
    """
    lowered = text.lower()
    has_keyword = any(word in lowered for word in ("导入", "读取", "更新", "import"))
    return has_keyword and bool(_extract_path(text))
def _extract_path(text: str) -> str:
    """Return the first ``.txt``/``.md`` file path mentioned in ``text``, or ``""``.

    A path wrapped in single or double quotes wins (this is the only way to
    pass paths containing spaces). Otherwise an unquoted Windows drive path is
    accepted, with either separator, e.g. ``C:\\notes\\w1.md`` or
    ``C:/notes/w1.md``.

    Matching rules:
      * Extensions match case-insensitively.
      * An extension followed by more ASCII letters or digits (``a.mdx``) is
        not accepted.
      * A drive letter glued to preceding ASCII letters is not accepted, which
        keeps URLs such as ``https://host/a.md`` from being read as drive paths.
    """
    quoted = re.search(r"['\"]([^'\"]+\.(?:txt|md))['\"]", text, re.IGNORECASE)
    if quoted:
        return quoted.group(1)
    match = re.search(
        r"(?<![A-Za-z0-9_])([A-Za-z]:[\\/][^\s]+?\.(?:txt|md))(?![A-Za-z0-9_])",
        text,
        re.IGNORECASE,
    )
    return match.group(1) if match else ""
def _extract_knowledge_base_id(text: str) -> str | None:
    """Return the knowledge-base id named in ``text``, or ``None``.

    Looks for the first marker followed by an id. The markers are
    ``knowledge_base_id``, 知识库 ("knowledge base"), 库, ``team_id``,
    团队 ("team") and ``team``. Between marker and id there may be ``:``,
    ``=``, spaces, or their full-width forms (U+FF1A, U+FF1D, U+3000). The
    id is the run of ASCII letters, digits, ``_``, ``-`` or CJK ideographs
    that follows.

    NOTE: because CJK ideographs are part of the id class, the capture runs on
    through following Chinese text (e.g. "知识库产品组的周会" yields "产品组的周会").
    """
    match = re.search(
        r"(?:knowledge_base_id|知识库|库|team_id|团队|team)[:= \u3000\uff1a\uff1d]*([A-Za-z0-9_\-\u4e00-\u9fff]+)",
        text,
    )
    return match.group(1) if match else None
def _extract_date(text: str) -> str:
    """Return the first valid ``YYYY-MM-DD`` calendar date in ``text``, or ``""``.

    A candidate must not be embedded in a longer run of digits, so
    ``12024-01-015`` yields nothing. It must also name a real day, so
    ``2024-02-30`` is skipped and the search continues with the next candidate.
    """
    for match in re.finditer(r"(?<!\d)\d{4}-\d{2}-\d{2}(?!\d)", text):
        candidate = match.group(0)
        try:
            datetime.date.fromisoformat(candidate)
        except ValueError:
            continue
        return candidate
    return ""
def _looks_like_list_bases(text: str) -> bool:
    """Tell whether ``text`` asks which knowledge bases exist.

    Matches any of a fixed set of Chinese phrasings as substrings:
    有哪些知识库 ("which knowledge bases are there"), 知识库列表 ("knowledge
    base list"), 列出知识库 ("list knowledge bases"), 可用知识库 ("available
    knowledge bases"), 有哪些库 ("which bases are there").
    """
    phrases = ("有哪些知识库", "知识库列表", "列出知识库", "可用知识库", "有哪些库")
    for phrase in phrases:
        if phrase in text:
            return True
    return False