Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions dreadnode/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
_total_usage_from_events,
)
from dreadnode.agent.hooks import Hook, retry_with_feedback
from dreadnode.agent.hooks.notification import NotificationBackend, TerminalNotificationBackend
from dreadnode.agent.reactions import (
Continue,
Fail,
Expand Down Expand Up @@ -62,6 +63,14 @@
CommitBehavior = t.Literal["always", "on-success"]


async def _safe_send(backend: NotificationBackend, event: AgentEvent, message: str) -> None:
"""Send notification with error handling."""
try:
await backend.send(event, message)
except Exception: # noqa: BLE001
logger.exception(f"Notification failed for {event.__class__.__name__}")


class AgentWarning(UserWarning):
"""Warning raised when an agent is used in a way that may not be safe or intended."""

Expand Down Expand Up @@ -111,6 +120,24 @@ class Agent(Model):
assert_scores: list[str] | t.Literal[True] = Field(default_factory=list)
"""Scores to ensure are truthy, otherwise the agent task is marked as failed."""

notifications: t.Annotated[bool | NotificationBackend | None, SkipValidation] = Config(
default=None, repr=False
)
"""
Enable notifications.
- True: Uses TerminalNotificationBackend (stderr output)
- NotificationBackend instance: Uses custom backend
- None/False: Disabled
"""
notification_events: list[type[AgentEvent]] | t.Literal["all"] = Config(
default="all", repr=False
)
"""Which event types to notify on. Defaults to all events."""
notification_formatter: t.Annotated[t.Callable[[AgentEvent], str] | None, SkipValidation] = (
Config(default=None, repr=False)
)
"""Custom formatter for notification messages. If None, uses event's default representation."""

_generator: rg.Generator | None = PrivateAttr(None, init=False)

@field_validator("tools", mode="before")
Expand All @@ -129,6 +156,49 @@ def validate_tools(cls, value: t.Any) -> t.Any:

return tools

def model_post_init(self, context: t.Any) -> None:
super().model_post_init(context)

# Auto-inject notification hook if enabled
if self.notifications:
backend = (
self.notifications
if isinstance(self.notifications, NotificationBackend)
else TerminalNotificationBackend()
)

self.hooks.append(
self._create_notification_hook(
backend,
self.notification_events,
self.notification_formatter,
)
)

def _create_notification_hook(
self,
backend: NotificationBackend,
events: list[type[AgentEvent]] | t.Literal["all"],
formatter: t.Callable[[AgentEvent], str] | None,
) -> Hook:
"""Create a notification hook that delegates formatting to events."""
import asyncio

async def notification_hook(event: AgentEvent) -> None:
# Filter events
if events != "all" and not any(isinstance(event, et) for et in events):
return

# Use custom formatter if provided, otherwise delegate to event
message = formatter(event) if formatter else event.format_notification()

# Fire and forget - don't block agent execution
_ = asyncio.create_task(_safe_send(backend, event, message)) # noqa: RUF006

return

return notification_hook

def __repr__(self) -> str:
description = shorten_string(self.description or "", 50)

Expand Down
31 changes: 31 additions & 0 deletions dreadnode/agent/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,22 @@ def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
border_style="dim",
)

def format_notification(self) -> str:
"""
Format this event as a human-readable notification message.
Override in subclasses for custom formatting.
"""
return f"{self.__class__.__name__}"

def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
yield self.format_as_panel()


@dataclass
class AgentStart(AgentEvent):
def format_notification(self) -> str:
return f"Starting agent: {self.agent.name}"

def format_as_panel(self, *, truncate: bool = False) -> Panel:
return Panel(
format_message(self.messages[0], truncate=truncate),
Expand Down Expand Up @@ -158,6 +168,10 @@ def __repr__(self) -> str:
message = f"Message(role={self.message.role}, content='{message_content}', tool_calls={tool_call_count})"
return f"GenerationEnd(message={message})"

def format_notification(self) -> str:
tokens = self.usage.total_tokens if self.usage else "unknown"
return f"Generation complete ({tokens} tokens)"

def format_as_panel(self, *, truncate: bool = False) -> Panel:
cost = round(self.estimated_cost, 6) if self.estimated_cost else ""
usage = str(self.usage) or ""
Expand All @@ -173,6 +187,9 @@ def format_as_panel(self, *, truncate: bool = False) -> Panel:

@dataclass
class AgentStalled(AgentEventInStep):
def format_notification(self) -> str:
return "Agent stalled: no tool calls and no stop conditions met"

def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
return Panel(
Text(
Expand All @@ -189,6 +206,9 @@ def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
class AgentError(AgentEventInStep):
error: BaseException

def format_notification(self) -> str:
return f"Error: {self.error.__class__.__name__}: {self.error!s}"

def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
return Panel(
repr(self),
Expand All @@ -205,6 +225,9 @@ class ToolStart(AgentEventInStep):
def __repr__(self) -> str:
return f"ToolStart(tool_call={self.tool_call})"

def format_notification(self) -> str:
return f"Starting tool: {self.tool_call.name}"

def format_as_panel(self, *, truncate: bool = False) -> Panel:
content: RenderableType
try:
Expand Down Expand Up @@ -245,6 +268,10 @@ def __repr__(self) -> str:
message = f"Message(role={self.message.role}, content='{message_content}')"
return f"ToolEnd(tool_call={self.tool_call}, message={message}, stop={self.stop})"

def format_notification(self) -> str:
status = " (requesting stop)" if self.stop else ""
return f"Finished tool: {self.tool_call.name}{status}"

def format_as_panel(self, *, truncate: bool = False) -> Panel:
panel = format_message(self.message, truncate=truncate)
subtitle = f"[dim]{self.tool_call.id}[/dim]"
Expand Down Expand Up @@ -294,6 +321,10 @@ class AgentEnd(AgentEvent):
stop_reason: "AgentStopReason"
result: "AgentResult"

def format_notification(self) -> str:
status = "Failed" if self.result.failed else "Finished"
return f"{status}: {self.stop_reason} (steps: {self.result.steps}, tokens: {self.result.usage.total_tokens})"

def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
res = self.result
status = "[bold red]Failed[/bold red]" if res.failed else "[bold green]Success[/bold green]"
Expand Down
12 changes: 12 additions & 0 deletions dreadnode/agent/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,24 @@
retry_with_feedback,
)
from dreadnode.agent.hooks.metrics import tool_metrics
from dreadnode.agent.hooks.notification import (
LogNotificationBackend,
NotificationBackend,
TerminalNotificationBackend,
WebhookNotificationBackend,
notify,
)
from dreadnode.agent.hooks.summarize import summarize_when_long

__all__ = [
"Hook",
"LogNotificationBackend",
"NotificationBackend",
"TerminalNotificationBackend",
"WebhookNotificationBackend",
"backoff_on_error",
"backoff_on_ratelimit",
"notify",
"retry_with_feedback",
"summarize_when_long",
"tool_metrics",
Expand Down
130 changes: 130 additions & 0 deletions dreadnode/agent/hooks/notification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import typing as t
from abc import ABC, abstractmethod

from loguru import logger

if t.TYPE_CHECKING:
import httpx

from dreadnode.agent.events import AgentEvent


class NotificationBackend(ABC):
@abstractmethod
async def send(self, event: "AgentEvent", message: str) -> None:
"""Send a notification for the given event."""


class LogNotificationBackend(NotificationBackend):
async def send(self, event: "AgentEvent", message: str) -> None:
logger.info(f"[{event.agent.name}] {message}")


class TerminalNotificationBackend(NotificationBackend):
async def send(self, event: "AgentEvent", message: str) -> None:
import sys

print(f"[{event.agent.name}] {message}", file=sys.stderr)


class WebhookNotificationBackend(NotificationBackend):
def __init__(self, url: str, headers: dict[str, str] | None = None, timeout: float = 5.0):
self.url = url
self.headers = headers or {}
self.timeout = timeout
self._client: httpx.AsyncClient | None = None

async def __aenter__(self) -> "WebhookNotificationBackend":
import httpx

self._client = httpx.AsyncClient(timeout=self.timeout)
return self

async def __aexit__(self, *args: object) -> None:
if self._client:
await self._client.aclose()

async def send(self, event: "AgentEvent", message: str) -> None:
import httpx

if not self._client:
self._client = httpx.AsyncClient(timeout=self.timeout)

payload = self._build_payload(event, message)
await self._client.post(self.url, json=payload, headers=self.headers)

def _build_payload(self, event: "AgentEvent", message: str) -> dict[str, str]:
"""Override this to customize webhook payload."""
return {
"agent": event.agent.name,
"event": event.__class__.__name__,
"message": message,
"timestamp": event.timestamp.isoformat(),
}


def notify(
event_type: "type[AgentEvent] | t.Callable[[AgentEvent], bool]",
message: str | t.Callable[["AgentEvent"], str] | None = None,
backend: NotificationBackend | None = None,
) -> t.Callable[["AgentEvent"], t.Awaitable[None]]:
"""
Create a notification hook that sends notifications when events occur.

Unlike other hooks, notification hooks don't affect agent execution - they return
None (no reaction) and run asynchronously to deliver notifications.

Args:
event_type: Event type to trigger on, or predicate function
message: Static message or callable that generates message from event.
If None, uses event.format_notification()
backend: Notification backend (defaults to terminal output)

Returns:
Hook that sends notifications

Example:
```python
from dreadnode.agent import Agent
from dreadnode.agent.events import ToolStart
from dreadnode.agent.hooks.notification import notify

agent = Agent(
name="analyzer",
hooks=[
notify(ToolStart), # Uses default formatting
notify(
ToolStart,
lambda e: f"Starting tool: {e.tool_name}",
),
],
)
```
"""
notification_backend = backend or TerminalNotificationBackend()

async def notification_hook(event: "AgentEvent") -> None:
should_notify = False

if isinstance(event_type, type):
should_notify = isinstance(event, event_type)
elif callable(event_type):
should_notify = event_type(event)

if not should_notify:
return

# Use custom message if provided, otherwise delegate to event
if message is None:
msg = event.format_notification()
else:
msg = message(event) if callable(message) else message

try:
await notification_backend.send(event, msg)
except Exception: # noqa: BLE001
logger.exception("Notification hook failed")

return

return notification_hook
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,9 @@ skip-magic-trailing-comma = false
"dreadnode/transforms/language.py" = [
"RUF001", # intentional use of ambiguous unicode characters for airt
]
"dreadnode/agent/tools/interaction.py" = [
"T201", # print required for user interaction
]
"dreadnode/agent/hooks/notification.py" = [
"T201", # print required for terminal notifications
]
Loading