Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 78 additions & 20 deletions dreadnode/agent/hooks/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,59 @@ def _get_last_input_tokens(event: AgentEvent) -> int:
return last_generation_event.usage.input_tokens if last_generation_event.usage else 0


def _find_tool_aware_boundary(
messages: list[rg.Message],
min_messages_to_keep: int,
) -> int:
"""
Find the best summarization boundary while preserving tool call/response pairs.

This prevents breaking tool messages that would cause API errors with strict models
(OpenAI, Anthropic) that require every tool_call_id to have a matching response.

Args:
messages: List of messages to analyze (excluding system message)
min_messages_to_keep: Minimum messages that must be kept after boundary

Returns:
Index where to split (messages[:idx] summarized, messages[idx:] kept)
Returns 0 if no valid boundary found
"""
# Build tool_call_id -> assistant message index mapping
tool_call_map: dict[str, int] = {}
for i, msg in enumerate(messages):
if msg.role == "assistant" and hasattr(msg, "tool_calls"):
for tc in getattr(msg, "tool_calls", None) or []:
if hasattr(tc, "id"):
tool_call_map[tc.id] = i

# Walk backward from desired split point to find first valid boundary
for boundary in range(len(messages) - min_messages_to_keep, -1, -1):
# Check if this boundary would orphan any tool responses
has_orphan = False
for msg in messages[boundary:]:
if msg.role == "tool" and hasattr(msg, "tool_call_id"):
tool_call_id = getattr(msg, "tool_call_id", None)
if tool_call_id is not None:
call_idx = tool_call_map.get(tool_call_id)
if call_idx is not None and call_idx < boundary:
has_orphan = True
break

if not has_orphan:
return boundary

return 0 # No valid boundary found


@component
def summarize_when_long(
model: str | rg.Generator | None = None,
max_tokens: int = 100_000,
min_messages_to_keep: int = 5,
guidance: str = "",
*,
preserve_tool_pairs: bool = True,
) -> "Hook":
"""
Creates a hook to manage the agent's context window by summarizing the conversation history.
Expand All @@ -66,6 +113,9 @@ def summarize_when_long(
(default is None, meaning no proactive summarization).
min_messages_to_keep: The minimum number of messages to retain after summarization (default is 5).
guidance: Additional guidance for the summarization process (default is "").
preserve_tool_pairs: If True, ensures tool call/response pairs stay together to avoid breaking
strict API requirements (OpenAI, Anthropic). Defaults to True. Set to False to use legacy
behavior that may break tool pairs but allows more aggressive summarization.
"""

if min_messages_to_keep < 2:
Expand All @@ -91,6 +141,10 @@ async def summarize_when_long( # noqa: PLR0912
guidance,
help="Additional guidance for the summarization process",
),
preserve_tool_pairs: bool = Config(
preserve_tool_pairs,
help="Preserve tool call/response pairs to avoid breaking strict API requirements",
),
) -> Reaction | None:
should_summarize = False

Expand Down Expand Up @@ -123,26 +177,30 @@ async def summarize_when_long( # noqa: PLR0912
messages.pop(0) if messages and messages[0].role == "system" else None
)

# Find the best point to summarize by walking the message list once.
# A boundary is valid after a simple assistant message or a finished tool block.
best_summarize_boundary = 0
for i, message in enumerate(messages):
# If the remaining messages are less than or equal to our minimum, we can't slice any further.
if len(messages) - i <= min_messages_to_keep:
break

# Condition 1: The message is an assistant response without tool calls.
is_simple_assistant = message.role == "assistant" and not getattr(
message, "tool_calls", None
)

# Condition 2: The message is the last in a block of tool responses.
is_last_tool_in_block = message.role == "tool" and (
i + 1 == len(messages) or messages[i + 1].role != "tool"
)

if is_simple_assistant or is_last_tool_in_block:
best_summarize_boundary = i + 1
# Find the best point to summarize
if preserve_tool_pairs:
# Use tool-aware boundary finding to prevent breaking tool call/response pairs
best_summarize_boundary = _find_tool_aware_boundary(messages, min_messages_to_keep)
else:
# Legacy behavior: walk the message list once looking for simple boundaries
best_summarize_boundary = 0
for i, message in enumerate(messages):
# If the remaining messages are less than or equal to our minimum, we can't slice any further.
if len(messages) - i <= min_messages_to_keep:
break

# Condition 1: The message is an assistant response without tool calls.
is_simple_assistant = message.role == "assistant" and not getattr(
message, "tool_calls", None
)

# Condition 2: The message is the last in a block of tool responses.
is_last_tool_in_block = message.role == "tool" and (
i + 1 == len(messages) or messages[i + 1].role != "tool"
)

if is_simple_assistant or is_last_tool_in_block:
best_summarize_boundary = i + 1

if best_summarize_boundary == 0:
return None # No valid slice point was found.
Expand Down
116 changes: 116 additions & 0 deletions tests/test_preserve_tool_pairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Tests for preserve_tool_pairs functionality in summarize_when_long hook."""

import rigging as rg

from dreadnode.agent.hooks.summarize import _find_tool_aware_boundary


def test_preserves_tool_pairs():
"""Tool call and response stay together when boundary would split them."""
messages = [
rg.Message("user", "Hello"),
rg.Message(
"assistant",
"Let me check",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "check", "arguments": "{}"},
}
],
),
rg.Message("tool", "Result", tool_call_id="call_1"),
rg.Message("assistant", "Done"),
rg.Message("user", "Thanks"),
]

# With min=3, naive boundary would be at index 2, keeping [2,3,4]
# But that would orphan the tool response at index 2 (call at index 1)
# So boundary should move back to index 1 to keep the pair together
boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3)
assert boundary == 1, f"Expected boundary 1 to keep tool pair together, got {boundary}"

# Verify the kept messages include the complete tool pair
kept = messages[boundary:]
assert len(kept) == 4
assert kept[0].role == "assistant"
assert kept[1].role == "tool"


def test_no_tools():
"""Works correctly without any tool messages."""
messages = [
rg.Message("user", "Hello"),
rg.Message("assistant", "Hi"),
rg.Message("user", "How are you"),
rg.Message("assistant", "Good"),
]

boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=2)
assert boundary == 2, "Should split at natural boundary"


def test_multiple_tool_pairs():
"""Handles multiple tool call/response pairs correctly."""
messages = [
rg.Message("user", "Do A and B"),
rg.Message(
"assistant",
"Running A",
tool_calls=[
{"id": "a", "type": "function", "function": {"name": "run_a", "arguments": "{}"}}
],
),
rg.Message("tool", "A done", tool_call_id="a"),
rg.Message(
"assistant",
"Running B",
tool_calls=[
{"id": "b", "type": "function", "function": {"name": "run_b", "arguments": "{}"}}
],
),
rg.Message("tool", "B done", tool_call_id="b"),
rg.Message("user", "Thanks"),
]

# With min=3, naive boundary at index 3 would keep [3,4,5]
# But index 4 (tool response "b") references call at index 3
# So boundary should move back to index 3 to keep the second pair
boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3)
assert boundary == 3, f"Expected boundary 3 to preserve second tool pair, got {boundary}"


def test_no_valid_boundary():
"""Returns 0 when min_messages would force splitting all tool pairs."""
messages = [
rg.Message(
"assistant",
"Start",
tool_calls=[
{"id": "1", "type": "function", "function": {"name": "start", "arguments": "{}"}}
],
),
rg.Message("tool", "Result 1", tool_call_id="1"),
rg.Message(
"assistant",
"Continue",
tool_calls=[
{"id": "2", "type": "function", "function": {"name": "continue", "arguments": "{}"}}
],
),
rg.Message("tool", "Result 2", tool_call_id="2"),
]

# With min=3, we'd need to keep last 3 messages
# Any boundary would orphan at least one tool response
# So should return 0 to keep everything
boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3)
assert boundary == 0, f"Should keep everything when no valid split exists, got {boundary}"


if __name__ == "__main__":
test_preserves_tool_pairs()
test_no_tools()
test_multiple_tool_pairs()
test_no_valid_boundary()