diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7aeee2cd8..77f4192b5 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -121,6 +121,7 @@ def __init__( *, sampling_capabilities: types.SamplingCapability | None = None, experimental_task_handlers: ExperimentalTaskHandlers | None = None, + capability_extensions: dict[str, Any] | None = None, ) -> None: super().__init__( read_stream, @@ -143,6 +144,10 @@ def __init__( # Experimental: Task handlers (use defaults if not provided) self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() + # Capability extensions to include in initialize request + # These are merged into ClientCapabilities using Pydantic's extra fields + self._capability_extensions = capability_extensions or {} + async def initialize(self) -> types.InitializeResult: sampling = ( (self._sampling_capabilities or types.SamplingCapability()) @@ -177,6 +182,7 @@ async def initialize(self) -> types.InitializeResult: experimental=None, roots=roots, tasks=self._task_handlers.build_capability(), + **self._capability_extensions, ), client_info=self._client_info, ), diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 78df8ed19..e45135ee8 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -768,3 +768,73 @@ async def mock_server(): await session.initialize() await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) + + +@pytest.mark.anyio +async def test_client_session_capability_extensions(): + """Test that capability_extensions are included in the initialize request.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities = None + + # Define capability extensions (e.g., UI extension) + capability_extensions = {"extensions": {"io.modelcontextprotocol/ui": {"mimeTypes": ["text/html;profile=mcp-app"]}}} + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + capability_extensions=capability_extensions, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that the capability extensions were included in the request + assert received_capabilities is not None + # The extensions should be present via Pydantic's extra fields + caps_dict = received_capabilities.model_dump() + assert "extensions" in caps_dict + assert "io.modelcontextprotocol/ui" in caps_dict["extensions"] + assert caps_dict["extensions"]["io.modelcontextprotocol/ui"]["mimeTypes"] == ["text/html;profile=mcp-app"]