59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any, Callable, Dict, Iterator, List, Optional
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(slots=True)
|
|||
|
|
class ToolCall:
|
|||
|
|
id: str
|
|||
|
|
name: str
|
|||
|
|
arguments: Dict[str, Any]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(slots=True)
|
|||
|
|
class AssistantTurn:
|
|||
|
|
content: str = ""
|
|||
|
|
reasoning: str = ""
|
|||
|
|
tool_calls: List[ToolCall] = field(default_factory=list)
|
|||
|
|
raw: Any = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(slots=True)
|
|||
|
|
class StreamEvent:
|
|||
|
|
type: str
|
|||
|
|
delta: str = ""
|
|||
|
|
turn: Optional[AssistantTurn] = None
|
|||
|
|
raw: Any = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentProvider(ABC):
    """Abstract backend that turns a conversation into an assistant turn.

    Subclasses must implement :meth:`generate`. :meth:`stream_generate` has a
    default, non-streaming fallback built on ``generate``; subclasses with
    native streaming can override it.
    """

    @abstractmethod
    def generate(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]]) -> AssistantTurn:
        """Produce one complete assistant turn.

        Args:
            messages: Conversation history as a list of dicts.
                NOTE(review): the per-message schema is not defined in this
                module; confirm against callers and subclasses.
            tools: Tool definitions offered to the model, as a list of dicts
                (schema likewise defined elsewhere).

        Returns:
            The finished :class:`AssistantTurn`.
        """
        raise NotImplementedError

    def stream_generate(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        should_cancel: Optional[Callable[[], bool]] = None,
    ) -> Iterator[StreamEvent]:
        """Yield the assistant turn as a sequence of :class:`StreamEvent`.

        This default implementation calls :meth:`generate` once. It then
        replays the finished turn: first ``"reasoning"`` events, then
        ``"content"`` events, each carrying a fixed-size text chunk in
        ``delta``. Last comes one ``"turn"`` event holding the complete turn.
        Every event's ``raw`` is ``turn.raw``.

        Args:
            messages: Passed through to :meth:`generate`.
            tools: Passed through to :meth:`generate`.
            should_cancel: Optional zero-argument predicate. It is polled
                before calling :meth:`generate` and before every event. Once
                it returns True the generator stops without emitting further
                events, including the final ``"turn"`` event.

        Yields:
            Stream events in the order described above.
        """

        def _cancelled() -> bool:
            # The parameter was previously accepted but ignored; honor it.
            return should_cancel is not None and bool(should_cancel())

        # Skip the (possibly expensive) generate() call if already cancelled.
        if _cancelled():
            return
        turn = self.generate(messages, tools)

        # Reasoning is replayed before content, matching the intended order.
        # _chunk_text yields nothing for empty text, so no extra guard is needed.
        for kind, text in (("reasoning", turn.reasoning), ("content", turn.content)):
            for chunk in _chunk_text(text):
                if _cancelled():
                    return
                yield StreamEvent(type=kind, delta=chunk, raw=turn.raw)

        if _cancelled():
            return
        yield StreamEvent(type="turn", turn=turn, raw=turn.raw)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _chunk_text(text: str, chunk_size: int = 32) -> Iterator[str]:
|
|||
|
|
if not text:
|
|||
|
|
return
|
|||
|
|
start = 0
|
|||
|
|
while start < len(text):
|
|||
|
|
yield text[start : start + chunk_size]
|
|||
|
|
start += chunk_size
|