meeting_memory/providers/base.py

59 lines
1.6 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterator, List, Optional
@dataclass(slots=True)
class ToolCall:
id: str
name: str
arguments: Dict[str, Any]
@dataclass(slots=True)
class AssistantTurn:
content: str = ""
reasoning: str = ""
tool_calls: List[ToolCall] = field(default_factory=list)
raw: Any = None
@dataclass(slots=True)
class StreamEvent:
type: str
delta: str = ""
turn: Optional[AssistantTurn] = None
raw: Any = None
class AgentProvider(ABC):
    """Abstract base for model backends that produce assistant turns.

    Subclasses must implement :meth:`generate`. :meth:`stream_generate` has a
    default implementation that simulates streaming on top of ``generate``.
    Providers with native streaming can override it.
    """

    @abstractmethod
    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Produce one complete assistant turn.

        Args:
            messages: Conversation history as a list of dicts (schema is
                provider-defined — presumably chat-style role/content
                messages; confirm against callers).
            tools: Tool definitions offered to the model, as dicts (schema is
                provider-defined).

        Returns:
            The full assistant turn, including any requested tool calls.
        """
        raise NotImplementedError

    def stream_generate(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        should_cancel: Optional[Callable[[], bool]] = None,
    ) -> Iterator[StreamEvent]:
        """Yield the assistant turn as a sequence of stream events.

        The default fallback calls :meth:`generate` once (blocking). It then
        replays the result as "reasoning" delta events, then "content" delta
        events, then a single final "turn" event carrying the complete
        :class:`AssistantTurn`. Every event carries ``turn.raw``.

        Args:
            messages: Passed through unchanged to :meth:`generate`.
            tools: Passed through unchanged to :meth:`generate`.
            should_cancel: Optional zero-argument predicate. It is polled
                before calling :meth:`generate` and before each event is
                yielded. Once it returns True, the stream ends immediately
                with no further events, including no final "turn" event.
                The blocking :meth:`generate` call itself cannot be
                interrupted by this fallback.

        Yields:
            StreamEvent objects in the order described above.
        """

        def cancelled() -> bool:
            # Previously `should_cancel` was accepted but never consulted.
            return should_cancel is not None and bool(should_cancel())

        # Avoid starting a potentially expensive generation if the caller has
        # already cancelled.
        if cancelled():
            return
        turn = self.generate(messages, tools)
        for event_type, text in (("reasoning", turn.reasoning), ("content", turn.content)):
            if not text:
                continue
            for chunk in _chunk_text(text):
                if cancelled():
                    return
                yield StreamEvent(type=event_type, delta=chunk, raw=turn.raw)
        if cancelled():
            return
        yield StreamEvent(type="turn", turn=turn, raw=turn.raw)
def _chunk_text(text: str, chunk_size: int = 32) -> Iterator[str]:
if not text:
return
start = 0
while start < len(text):
yield text[start : start + chunk_size]
start += chunk_size