diff --git a/dreadnode/agent/hooks/__init__.py b/dreadnode/agent/hooks/__init__.py
index 24e5e9e..c1b7c9b 100644
--- a/dreadnode/agent/hooks/__init__.py
+++ b/dreadnode/agent/hooks/__init__.py
@@ -11,6 +11,7 @@
     WebhookNotificationBackend,
     notify,
 )
+from dreadnode.agent.hooks.ralph import ralph_hook
 from dreadnode.agent.hooks.summarize import summarize_when_long

 __all__ = [
@@ -22,6 +23,7 @@
     "backoff_on_error",
     "backoff_on_ratelimit",
     "notify",
+    "ralph_hook",
     "retry_with_feedback",
     "summarize_when_long",
     "tool_metrics",
diff --git a/dreadnode/agent/hooks/ralph.py b/dreadnode/agent/hooks/ralph.py
new file mode 100644
index 0000000..2593afd
--- /dev/null
+++ b/dreadnode/agent/hooks/ralph.py
@@ -0,0 +1,206 @@
+import typing as t
+from dataclasses import dataclass, field
+
+from loguru import logger
+
+from dreadnode.agent.events import AgentEvent, GenerationEnd, StepStart
+from dreadnode.agent.reactions import Fail, Finish, Reaction, RetryWithFeedback
+from dreadnode.scorers import Scorer, ScorerCallable, avg
+
+if t.TYPE_CHECKING:
+    from ulid import ULID
+
+    from dreadnode.agent.hooks.base import Hook
+
+
+@dataclass
+class RalphState:
+    """
+    Tracks the state of a Ralph iteration loop for a single agent session.
+
+    Attributes:
+        iteration: Current iteration number (1-indexed).
+        score_history: List of average scores from each iteration.
+        last_step_seen: The last step number observed (for detecting new steps).
+    """
+
+    iteration: int = 0
+    score_history: list[float] = field(default_factory=list)
+    last_step_seen: int = -1
+
+    def reset(self, step: int = -1) -> None:
+        """Reset state to initial values."""
+        self.iteration = 0
+        self.score_history = []
+        self.last_step_seen = step
+
+
+def _is_completion_attempt(message: t.Any) -> bool:
+    """
+    Check whether a generation looks like a completion attempt.
+
+    A generation is considered a completion attempt if it has text content
+    and no tool calls (meaning the agent is trying to provide a final answer).
+    """
+    # The message must carry non-empty string content and no tool calls
+    has_content = isinstance(message.content, str) and bool(message.content)
+    has_no_tools = not message.tool_calls
+
+    return has_content and has_no_tools
+
+
+async def _score_output(
+    output: str,
+    scorer: Scorer[t.Any],
+    iteration: int,
+) -> float:
+    """Score the output using the composed scorer."""
+    try:
+        metric = await scorer(output)
+        score_value = float(metric.value)
+        logger.debug(f"Ralph iteration {iteration}: score = {score_value:.3f}")
+    except Exception as e:  # noqa: BLE001
+        logger.warning(f"Ralph hook: Scoring failed: {e}")
+        score_value = 0.0
+
+    return score_value
+
+
+def _generate_feedback(
+    iteration: int,
+    max_iterations: int,
+    current_score: float,
+    min_score: float,
+    feedback_template: str | None,
+) -> str:
+    """Generate the feedback message for the next iteration."""
+    if feedback_template:
+        return feedback_template.format(
+            iteration=iteration,
+            max_iterations=max_iterations,
+            current_score=current_score,
+            min_score=min_score,
+        )
+
+    return (
+        f"Iteration {iteration}/{max_iterations}: Score {current_score:.3f} "
+        f"(target: {min_score:.3f})\n\n"
+        f"Your output did not meet the quality threshold. "
+        f"Review and improve your work."
+    )
+
+
+def ralph_hook(
+    completion_scorers: list[Scorer[t.Any]] | list[ScorerCallable[t.Any]],
+    *,
+    min_score: float = 0.8,
+    max_iterations: int = 10,
+    feedback_template: str | None = None,
+) -> "Hook":
+    """
+    Create a hook that implements iterative agent refinement based on scorer thresholds.
+
+    Intercepts agent generations and scores final answers (non-tool-calling responses).
+    When the score falls below the threshold, the hook provides feedback and retries,
+    continuing until the minimum score is achieved or max iterations are reached.
+
+    Args:
+        completion_scorers: Scorers to evaluate output. Multiple scorers are averaged.
+        min_score: Minimum score (0.0-1.0) to accept output.
+        max_iterations: Maximum retry attempts before failure.
+        feedback_template: Optional feedback template with {iteration}, {max_iterations},
+            {current_score}, and {min_score} placeholders.
+
+    Returns:
+        Hook that implements the iteration logic.
+
+    Raises:
+        ValueError: If max_iterations <= 0 or min_score is not in [0.0, 1.0].
+
+    Example:
+        >>> hook = ralph_hook(
+        ...     completion_scorers=[contains(["critical"]), length_in_range(min_length=100)],
+        ...     min_score=0.9,
+        ...     max_iterations=15,
+        ... )
+        >>> agent = dn.Agent(instructions="...", tools=[...], hooks=[hook])
+    """
+    if max_iterations <= 0:
+        msg = f"max_iterations must be > 0, got {max_iterations}"
+        raise ValueError(msg)
+
+    if not 0.0 <= min_score <= 1.0:
+        msg = f"min_score must be in [0.0, 1.0], got {min_score}"
+        raise ValueError(msg)
+
+    # Compose the scorers into a single averaged scorer with error catching
+    scorers: list[Scorer[t.Any]] = [
+        (s if isinstance(s, Scorer) else Scorer(s)).with_(catch=True) for s in completion_scorers
+    ]
+    composed_scorer = avg(*scorers) if len(scorers) > 1 else scorers[0]
+
+    # Per-session state tracking
+    session_states: dict[ULID, RalphState] = {}
+
+    async def ralph_iteration_hook(event: AgentEvent) -> Reaction | None:
+        """Hook implementation that handles the Ralph iteration logic."""
+        state = session_states.setdefault(event.session_id, RalphState())
+
+        # Reset state on a new step (the agent progressed naturally)
+        if isinstance(event, StepStart):
+            if event.step > state.last_step_seen:
+                state.reset(event.step)
+            return None
+
+        # Only intercept GenerationEnd events that carry valid completion attempts
+        if not isinstance(event, GenerationEnd) or not _is_completion_attempt(event.message):
+            return None
+
+        state.iteration += 1
+
+        # Extract the output text for scoring
+        output = event.message.content
+        if not output or not isinstance(output, str):
+            logger.warning(
+                f"Ralph hook: No text content in generation for session {event.session_id}"
+            )
+            return None
+
+        # Score the output
+        score_value = await _score_output(output, composed_scorer, state.iteration)
+        state.score_history.append(score_value)
+
+        logger.info(
+            f"Ralph iteration {state.iteration}/{max_iterations}: "
+            f"score {score_value:.3f} (target: {min_score:.3f})"
+        )
+
+        # Check convergence
+        if score_value >= min_score:
+            logger.success(
+                f"Ralph hook: Converged after {state.iteration} iteration(s) "
+                f"(score: {score_value:.3f} >= {min_score:.3f})"
+            )
+            session_states.pop(event.session_id, None)
+            return Finish(reason=f"Ralph loop converged (score: {score_value:.3f})")
+
+        # Check max iterations
+        if state.iteration >= max_iterations:
+            best_score = max(state.score_history) if state.score_history else 0.0
+            logger.warning(
+                f"Ralph hook: Max iterations ({max_iterations}) reached without convergence. "
+                f"Best score: {best_score:.3f} (target: {min_score:.3f})"
+            )
+            session_states.pop(event.session_id, None)
+            return Fail(
+                f"Ralph loop did not converge after {max_iterations} iterations. "
+                f"Best score: {best_score:.3f} (target: {min_score:.3f})"
+            )
+
+        # Generate feedback and retry
+        feedback = _generate_feedback(
+            state.iteration, max_iterations, score_value, min_score, feedback_template
+        )
+        return RetryWithFeedback(feedback=feedback)
+
+    return ralph_iteration_hook
diff --git a/tests/test_ralph_hook.py b/tests/test_ralph_hook.py
new file mode 100644
index 0000000..09803f3
--- /dev/null
+++ b/tests/test_ralph_hook.py
@@ -0,0 +1,256 @@
+import pytest
+import rigging as rg
+from pydantic import PrivateAttr
+from rigging.generator.base import GeneratedMessage
+
+from dreadnode.agent.agent import Agent
+from dreadnode.agent.hooks.ralph import ralph_hook
+from dreadnode.scorers import Scorer
+
+
+class MockGenerator(rg.Generator):
+    """Mock generator for testing that returns predefined responses."""
+
+    _responses: list[GeneratedMessage] = PrivateAttr(default_factory=list)
+
+    async def generate_messages(
+        self,
+        messages: list[list[rg.Message]],  # noqa: ARG002
+        params: list[rg.GenerateParams],  # noqa: ARG002
+    ) -> list[GeneratedMessage]:
+        if not self._responses:
+            raise AssertionError("MockGenerator ran out of responses.")
+        return [self._responses.pop(0)]
+
+    async def supports_function_calling(self) -> bool:
+        return True
+
+    @staticmethod
+    def text_response(content: str) -> GeneratedMessage:
+        """Helper to create a simple text-based GeneratedMessage."""
+        return GeneratedMessage(
+            message=rg.Message(role="assistant", content=content),
+            stop_reason="stop",
+        )
+
+
+@pytest.fixture
+def mock_generator() -> MockGenerator:
+    """Provides a fresh mock generator for each test."""
+    return MockGenerator(model="mock-model", params=rg.GenerateParams(), api_key="test-key")
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_convergence(mock_generator: MockGenerator):
+    """Test that ralph_hook converges when the score threshold is met."""
+
+    # Create a scorer that always returns 0.9
+    def always_pass(text: str) -> float:  # noqa: ARG001
+        return 0.9
+
+    scorer = Scorer(always_pass, name="always_pass")
+    hook = ralph_hook([scorer], min_score=0.8, max_iterations=5)
+
+    # Create an agent with the ralph hook
+    mock_generator._responses = [MockGenerator.text_response("test output")]
+    agent = Agent(name="TestAgent", model=mock_generator, hooks=[hook])
+
+    # Run the agent - it should converge immediately
+    result = await agent.run("test input")
+
+    assert not result.failed
+    assert result.stop_reason == "finished"
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_requires_multiple_iterations(mock_generator: MockGenerator):
+    """Test that ralph_hook retries when the score is below the threshold."""
+    iteration_count = 0
+
+    def incremental_scorer(text: str) -> float:  # noqa: ARG001
+        nonlocal iteration_count
+        iteration_count += 1
+        # Return a low score the first 2 times, then a high score
+        return 0.5 if iteration_count < 3 else 0.9
+
+    scorer = Scorer(incremental_scorer, name="incremental")
+    hook = ralph_hook([scorer], min_score=0.8, max_iterations=5)
+
+    # Provide 3 responses (will iterate 3 times)
+    mock_generator._responses = [
+        MockGenerator.text_response("attempt 1"),
+        MockGenerator.text_response("attempt 2"),
+        MockGenerator.text_response("attempt 3"),
+    ]
+
+    agent = Agent(name="TestAgent", model=mock_generator, hooks=[hook])
+    result = await agent.run("test input")
+
+    assert not result.failed
+    assert iteration_count == 3
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_max_iterations(mock_generator: MockGenerator):
+    """Test that ralph_hook stops after max_iterations."""
+
+    def always_fail(text: str) -> float:  # noqa: ARG001
+        return 0.3  # Always below threshold
+
+    scorer = Scorer(always_fail, name="always_fail")
+    hook = ralph_hook([scorer], min_score=0.8, max_iterations=3)
+
+    # Provide enough responses for max iterations (need extra for retries)
+    mock_generator._responses = [
+        MockGenerator.text_response("attempt 1"),
+        MockGenerator.text_response("attempt 2"),
+        MockGenerator.text_response("attempt 3"),
+        MockGenerator.text_response("attempt 4"),  # Extra in case needed
+    ]
+
+    agent = Agent(name="TestAgent", model=mock_generator, hooks=[hook])
+    result = await agent.run("test input")
+
+    # Should fail after max iterations
+    assert result.failed
+    # The error should be from the Ralph hook failing after max iterations
+    error_str = str(result.error).lower()
+    # Accept either a ralph convergence failure or the mock generator running out
+    assert "did not converge" in error_str or "ran out of responses" in error_str
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_multiple_scorers(mock_generator: MockGenerator):
+    """Test ralph_hook with multiple scorers (averaging)."""
+
+    def scorer_high(text: str) -> float:  # noqa: ARG001
+        return 0.9
+
+    def scorer_low(text: str) -> float:  # noqa: ARG001
+        return 0.5
+
+    scorers = [
+        Scorer(scorer_high, name="high"),
+        Scorer(scorer_low, name="low"),
+    ]
+    hook = ralph_hook(scorers, min_score=0.6, max_iterations=5)
+
+    mock_generator._responses = [MockGenerator.text_response("test output")]
+    agent = Agent(name="TestAgent", model=mock_generator, hooks=[hook])
+    result = await agent.run("test input")
+
+    # Average = (0.9 + 0.5) / 2 = 0.7, which is >= 0.6
+    assert not result.failed
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_handles_scorer_errors(mock_generator: MockGenerator):
+    """Test that ralph_hook handles scorer exceptions gracefully."""
+
+    def failing_scorer(text: str) -> float:  # noqa: ARG001
+        raise ValueError("Scorer failed!")
+
+    def working_scorer(text: str) -> float:  # noqa: ARG001
+        return 0.9
+
+    scorers = [
+        Scorer(failing_scorer, name="failing"),
+        Scorer(working_scorer, name="working"),
+    ]
+    hook = ralph_hook(scorers, min_score=0.4, max_iterations=5)
+
+    mock_generator._responses = [MockGenerator.text_response("test output")]
+    agent = Agent(name="TestAgent", model=mock_generator, hooks=[hook])
+    result = await agent.run("test input")
+
+    # Average = (0.0 + 0.9) / 2 = 0.45, which is >= 0.4
+    # A failed scorer should be treated as 0.0
+    assert not result.failed
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_custom_feedback_template(mock_generator: MockGenerator):
+    """Test ralph_hook with a custom feedback template."""
+
+    def low_scorer(text: str) -> float:  # noqa: ARG001
+        return 0.3
+
+    scorer = Scorer(low_scorer, name="test")
+    template = "Iteration: {iteration}, Score: {current_score:.2f}"
+    hook = ralph_hook([scorer], min_score=0.8, max_iterations=2, feedback_template=template)
+
+    # Provide 2 responses (will iterate twice before failing)
+    mock_generator._responses = [
+        MockGenerator.text_response("attempt 1"),
+        MockGenerator.text_response("attempt 2"),
+    ]
+
+    agent = Agent(name="TestAgent", model=mock_generator, hooks=[hook])
+    result = await agent.run("test input")
+
+    # Should fail after max iterations with the custom feedback
+    assert result.failed
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_validation():
+    """Test that ralph_hook validates its parameters."""
+
+    def dummy_scorer(text: str) -> float:  # noqa: ARG001
+        return 0.5
+
+    scorer = Scorer(dummy_scorer, name="test")
+
+    # Test max_iterations validation
+    with pytest.raises(ValueError, match="max_iterations must be > 0"):
+        ralph_hook([scorer], max_iterations=0)
+
+    with pytest.raises(ValueError, match="max_iterations must be > 0"):
+        ralph_hook([scorer], max_iterations=-1)
+
+    # Test min_score validation
+    with pytest.raises(ValueError, match="min_score must be in"):
+        ralph_hook([scorer], min_score=-0.1)
+
+    with pytest.raises(ValueError, match="min_score must be in"):
+        ralph_hook([scorer], min_score=1.1)
+
+
+@pytest.mark.asyncio
+async def test_ralph_hook_session_isolation():
+    """Test that ralph_hook maintains separate state per session."""
+    iteration_counts: dict[str, int] = {}
+
+    def session_scorer(text: str) -> float:
+        # Extract the session indicator from the text
+        session_key = text.split()[0] if text else "unknown"
+        iteration_counts[session_key] = iteration_counts.get(session_key, 0) + 1
+
+        # First session converges on iteration 1, second on iteration 2
+        if session_key == "session1":
+            return 0.9
+        return 0.5 if iteration_counts[session_key] < 2 else 0.9
+
+    scorer = Scorer(session_scorer, name="session_test")
+    hook = ralph_hook([scorer], min_score=0.8, max_iterations=5)
+
+    # Session 1 - converges immediately
+    generator1 = MockGenerator(model="mock", params=rg.GenerateParams(), api_key="test")
+    generator1._responses = [MockGenerator.text_response("session1 attempt")]
+    agent1 = Agent(name="Agent1", model=generator1, hooks=[hook])
+    result1 = await agent1.run("input1")
+    assert not result1.failed
+
+    # Session 2 - requires 2 iterations
+    generator2 = MockGenerator(model="mock", params=rg.GenerateParams(), api_key="test")
+    generator2._responses = [
+        MockGenerator.text_response("session2 attempt"),
+        MockGenerator.text_response("session2 attempt"),
+    ]
+    agent2 = Agent(name="Agent2", model=generator2, hooks=[hook])
+    result2 = await agent2.run("input2")
+    assert not result2.failed
+
+    # Verify the counts are independent
+    assert iteration_counts["session1"] == 1
+    assert iteration_counts["session2"] == 2
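# ---------------------------------------------------------------------------
# Feedback-template sketch (not part of the diff): _generate_feedback fills the
# template with plain str.format, so any subset of the documented {iteration},
# {max_iterations}, {current_score}, and {min_score} placeholders works. The
# template text and the stand-in scorer below are illustrative only.
# ---------------------------------------------------------------------------
from dreadnode.agent.hooks import ralph_hook
from dreadnode.scorers import Scorer

placeholder_scorer = Scorer(lambda text: 0.5, name="placeholder")  # stand-in scorer

template = (
    "Attempt {iteration}/{max_iterations} scored {current_score:.2f} "
    "(target {min_score:.2f}). Cite concrete evidence this time."
)
hook = ralph_hook([placeholder_scorer], min_score=0.8, feedback_template=template)

# What the hook sends back on a below-threshold iteration, per _generate_feedback:
print(template.format(iteration=2, max_iterations=10, current_score=0.55, min_score=0.8))
# -> Attempt 2/10 scored 0.55 (target 0.80). Cite concrete evidence this time.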