diff --git a/dreadnode/agent/hooks/summarize.py b/dreadnode/agent/hooks/summarize.py index 8999d25..538261c 100644 --- a/dreadnode/agent/hooks/summarize.py +++ b/dreadnode/agent/hooks/summarize.py @@ -44,12 +44,59 @@ def _get_last_input_tokens(event: AgentEvent) -> int: return last_generation_event.usage.input_tokens if last_generation_event.usage else 0 +def _find_tool_aware_boundary( + messages: list[rg.Message], + min_messages_to_keep: int, +) -> int: + """ + Find the best summarization boundary while preserving tool call/response pairs. + + This prevents breaking tool messages that would cause API errors with strict models + (OpenAI, Anthropic) that require every tool_call_id to have a matching response. + + Args: + messages: List of messages to analyze (excluding system message) + min_messages_to_keep: Minimum messages that must be kept after boundary + + Returns: + Index where to split (messages[:idx] summarized, messages[idx:] kept) + Returns 0 if no valid boundary found + """ + # Build tool_call_id -> assistant message index mapping + tool_call_map: dict[str, int] = {} + for i, msg in enumerate(messages): + if msg.role == "assistant" and hasattr(msg, "tool_calls"): + for tc in getattr(msg, "tool_calls", None) or []: + if hasattr(tc, "id"): + tool_call_map[tc.id] = i + + # Walk backward from desired split point to find first valid boundary + for boundary in range(len(messages) - min_messages_to_keep, -1, -1): + # Check if this boundary would orphan any tool responses + has_orphan = False + for msg in messages[boundary:]: + if msg.role == "tool" and hasattr(msg, "tool_call_id"): + tool_call_id = getattr(msg, "tool_call_id", None) + if tool_call_id is not None: + call_idx = tool_call_map.get(tool_call_id) + if call_idx is not None and call_idx < boundary: + has_orphan = True + break + + if not has_orphan: + return boundary + + return 0 # No valid boundary found + + @component def summarize_when_long( model: str | rg.Generator | None = None, max_tokens: int = 100_000, min_messages_to_keep: int = 5, guidance: str = "", + *, + preserve_tool_pairs: bool = True, ) -> "Hook": """ Creates a hook to manage the agent's context window by summarizing the conversation history. @@ -66,6 +113,9 @@ def summarize_when_long( (default is None, meaning no proactive summarization). min_messages_to_keep: The minimum number of messages to retain after summarization (default is 5). guidance: Additional guidance for the summarization process (default is ""). + preserve_tool_pairs: If True, ensures tool call/response pairs stay together to avoid breaking + strict API requirements (OpenAI, Anthropic). Defaults to True. Set to False to use legacy + behavior that may break tool pairs but allows more aggressive summarization. """ if min_messages_to_keep < 2: @@ -91,6 +141,10 @@ async def summarize_when_long( # noqa: PLR0912 guidance, help="Additional guidance for the summarization process", ), + preserve_tool_pairs: bool = Config( + preserve_tool_pairs, + help="Preserve tool call/response pairs to avoid breaking strict API requirements", + ), ) -> Reaction | None: should_summarize = False @@ -123,26 +177,30 @@ async def summarize_when_long( # noqa: PLR0912 messages.pop(0) if messages and messages[0].role == "system" else None ) - # Find the best point to summarize by walking the message list once. - # A boundary is valid after a simple assistant message or a finished tool block. - best_summarize_boundary = 0 - for i, message in enumerate(messages): - # If the remaining messages are less than or equal to our minimum, we can't slice any further. - if len(messages) - i <= min_messages_to_keep: - break - - # Condition 1: The message is an assistant response without tool calls. - is_simple_assistant = message.role == "assistant" and not getattr( - message, "tool_calls", None - ) - - # Condition 2: The message is the last in a block of tool responses. - is_last_tool_in_block = message.role == "tool" and ( - i + 1 == len(messages) or messages[i + 1].role != "tool" - ) - - if is_simple_assistant or is_last_tool_in_block: - best_summarize_boundary = i + 1 + # Find the best point to summarize + if preserve_tool_pairs: + # Use tool-aware boundary finding to prevent breaking tool call/response pairs + best_summarize_boundary = _find_tool_aware_boundary(messages, min_messages_to_keep) + else: + # Legacy behavior: walk the message list once looking for simple boundaries + best_summarize_boundary = 0 + for i, message in enumerate(messages): + # If the remaining messages are less than or equal to our minimum, we can't slice any further. + if len(messages) - i <= min_messages_to_keep: + break + + # Condition 1: The message is an assistant response without tool calls. + is_simple_assistant = message.role == "assistant" and not getattr( + message, "tool_calls", None + ) + + # Condition 2: The message is the last in a block of tool responses. + is_last_tool_in_block = message.role == "tool" and ( + i + 1 == len(messages) or messages[i + 1].role != "tool" + ) + + if is_simple_assistant or is_last_tool_in_block: + best_summarize_boundary = i + 1 if best_summarize_boundary == 0: return None # No valid slice point was found. diff --git a/tests/test_preserve_tool_pairs.py b/tests/test_preserve_tool_pairs.py new file mode 100644 index 0000000..9ba23c3 --- /dev/null +++ b/tests/test_preserve_tool_pairs.py @@ -0,0 +1,116 @@ +"""Tests for preserve_tool_pairs functionality in summarize_when_long hook.""" + +import rigging as rg + +from dreadnode.agent.hooks.summarize import _find_tool_aware_boundary + + +def test_preserves_tool_pairs(): + """Tool call and response stay together when boundary would split them.""" + messages = [ + rg.Message("user", "Hello"), + rg.Message( + "assistant", + "Let me check", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "check", "arguments": "{}"}, + } + ], + ), + rg.Message("tool", "Result", tool_call_id="call_1"), + rg.Message("assistant", "Done"), + rg.Message("user", "Thanks"), + ] + + # With min=3, naive boundary would be at index 2, keeping [2,3,4] + # But that would orphan the tool response at index 2 (call at index 1) + # So boundary should move back to index 1 to keep the pair together + boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3) + assert boundary == 1, f"Expected boundary 1 to keep tool pair together, got {boundary}" + + # Verify the kept messages include the complete tool pair + kept = messages[boundary:] + assert len(kept) == 4 + assert kept[0].role == "assistant" + assert kept[1].role == "tool" + + +def test_no_tools(): + """Works correctly without any tool messages.""" + messages = [ + rg.Message("user", "Hello"), + rg.Message("assistant", "Hi"), + rg.Message("user", "How are you"), + rg.Message("assistant", "Good"), + ] + + boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=2) + assert boundary == 2, "Should split at natural boundary" + + +def test_multiple_tool_pairs(): + """Handles multiple tool call/response pairs correctly.""" + messages = [ + rg.Message("user", "Do A and B"), + rg.Message( + "assistant", + "Running A", + tool_calls=[ + {"id": "a", "type": "function", "function": {"name": "run_a", "arguments": "{}"}} + ], + ), + rg.Message("tool", "A done", tool_call_id="a"), + rg.Message( + "assistant", + "Running B", + tool_calls=[ + {"id": "b", "type": "function", "function": {"name": "run_b", "arguments": "{}"}} + ], + ), + rg.Message("tool", "B done", tool_call_id="b"), + rg.Message("user", "Thanks"), + ] + + # With min=3, naive boundary at index 3 would keep [3,4,5] + # But index 4 (tool response "b") references call at index 3 + # So boundary should move back to index 3 to keep the second pair + boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3) + assert boundary == 3, f"Expected boundary 3 to preserve second tool pair, got {boundary}" + + +def test_no_valid_boundary(): + """Returns 0 when min_messages would force splitting all tool pairs.""" + messages = [ + rg.Message( + "assistant", + "Start", + tool_calls=[ + {"id": "1", "type": "function", "function": {"name": "start", "arguments": "{}"}} + ], + ), + rg.Message("tool", "Result 1", tool_call_id="1"), + rg.Message( + "assistant", + "Continue", + tool_calls=[ + {"id": "2", "type": "function", "function": {"name": "continue", "arguments": "{}"}} + ], + ), + rg.Message("tool", "Result 2", tool_call_id="2"), + ] + + # With min=3, we'd need to keep last 3 messages + # Any boundary would orphan at least one tool response + # So should return 0 to keep everything + boundary = _find_tool_aware_boundary(messages, min_messages_to_keep=3) + assert boundary == 0, f"Should keep everything when no valid split exists, got {boundary}" + + +if __name__ == "__main__": + test_preserves_tool_pairs() + test_no_tools() + test_multiple_tool_pairs() + test_no_valid_boundary()