Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@
from typing import Optional
from collections.abc import Iterator

try:
from anyio import create_memory_object_stream, create_task_group
from mcp.types import (
JSONRPCMessage,
JSONRPCRequest,
)
from mcp.shared.message import SessionMessage
except ImportError:
create_memory_object_stream = None
create_task_group = None
JSONRPCMessage = None
JSONRPCRequest = None
SessionMessage = None


SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"

Expand Down Expand Up @@ -592,6 +606,84 @@ def suppress_deprecation_warnings():
yield


@pytest.fixture
def get_initialization_payload():
def inner(request_id: str):
return SessionMessage( # type: ignore
message=JSONRPCMessage( # type: ignore
root=JSONRPCRequest( # type: ignore
jsonrpc="2.0",
id=request_id,
method="initialize",
params={
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
)
)
)

return inner


@pytest.fixture
def get_mcp_command_payload():
def inner(method: str, params, request_id: str):
return SessionMessage( # type: ignore
message=JSONRPCMessage( # type: ignore
root=JSONRPCRequest( # type: ignore
jsonrpc="2.0",
id=request_id,
method=method,
params=params,
)
)
)

return inner


@pytest.fixture
def stdio(get_initialization_payload, get_mcp_command_payload):
async def inner(server, method: str, params, request_id: str | None = None):
if request_id is None:
request_id = "1" # arbitrary

read_stream_writer, read_stream = create_memory_object_stream(0) # type: ignore
write_stream, write_stream_reader = create_memory_object_stream(0) # type: ignore

result = {}

async def run_server():
await server.run(
read_stream, write_stream, server.create_initialization_options()
)

async def simulate_client(tg, result):
init_request = get_initialization_payload("1")
await read_stream_writer.send(init_request)

await write_stream_reader.receive()

request = get_mcp_command_payload(
method, params=params, request_id=request_id
)
await read_stream_writer.send(request)

result["response"] = await write_stream_reader.receive()

tg.cancel_scope.cancel()

async with create_task_group() as tg: # type: ignore
tg.start_soon(run_server)
tg.start_soon(simulate_client, tg, result)

return result["response"]

return inner


class MockServerRequestHandler(BaseHTTPRequestHandler):
def do_GET(self): # noqa: N802
# Process an HTTP GET request and return a response.
Expand Down
Loading
Loading