Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
capability_extensions: dict[str, Any] | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -143,6 +144,10 @@ def __init__(
# Experimental: Task handlers (use defaults if not provided)
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()

# Capability extensions to include in initialize request
# These are merged into ClientCapabilities using Pydantic's extra fields
self._capability_extensions = capability_extensions or {}

async def initialize(self) -> types.InitializeResult:
sampling = (
(self._sampling_capabilities or types.SamplingCapability())
Expand Down Expand Up @@ -177,6 +182,7 @@ async def initialize(self) -> types.InitializeResult:
experimental=None,
roots=roots,
tasks=self._task_handlers.build_capability(),
**self._capability_extensions,
Copy link
Member

@Kludex Kludex Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it should allow any key here, this will be an issue if in the future ClientCapabilities has a new key but it's used by someone.

Suggested change
**self._capability_extensions,
extensions=self._capability_extensions,

If the spec doesn't have this explicit, I think the python SDK should.

),
client_info=self._client_info,
),
Expand Down
70 changes: 70 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,73 @@ async def mock_server():
await session.initialize()

await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)


@pytest.mark.anyio
async def test_client_session_capability_extensions():
"""Test that capability_extensions are included in the initialize request."""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

received_capabilities = None

# Define capability extensions (e.g., UI extension)
capability_extensions = {"extensions": {"io.modelcontextprotocol/ui": {"mimeTypes": ["text/html;profile=mcp-app"]}}}

async def mock_server():
nonlocal received_capabilities

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request.root, InitializeRequest)
received_capabilities = request.root.params.capabilities

result = ServerResult(
InitializeResult(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
server_info=Implementation(name="mock-server", version="0.1.0"),
)
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
capability_extensions=capability_extensions,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()

# Assert that the capability extensions were included in the request
assert received_capabilities is not None
# The extensions should be present via Pydantic's extra fields
caps_dict = received_capabilities.model_dump()
assert "extensions" in caps_dict
assert "io.modelcontextprotocol/ui" in caps_dict["extensions"]
assert caps_dict["extensions"]["io.modelcontextprotocol/ui"]["mimeTypes"] == ["text/html;profile=mcp-app"]