Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This repo provides a Python implementation of Connect, including both client and

## Features

- **Clients**: Both synchronous and asynchronous clients backed by [httpx](https://www.python-httpx.org/)
- **Clients**: Both synchronous and asynchronous clients backed by [pyqwest](https://pyqwest.dev/)
- **Servers**: WSGI and ASGI server implementations for use with any Python app server
- **Type Safety**: Fully type-annotated, including the generated code
- **Compression**: Built-in support for gzip, brotli, and zstd compression
Expand Down Expand Up @@ -63,8 +63,8 @@ it can be referenced as `protoc-gen-connect-python`.
Then, you can use `protoc-gen-connect-python` as a local plugin:

```yaml
- local: .venv/bin/protoc-gen-connect-python
out: .
- local: .venv/bin/protoc-gen-connect-python
out: .
```

Alternatively, download a precompiled binary from the
Expand All @@ -79,18 +79,14 @@ For more usage details, see the [docs](./docs/usage.md).
### Basic Client Usage

```python
import httpx
from your_service_pb2 import HelloRequest, HelloResponse
from your_service_connect import HelloServiceClient

# Create async client
async def main():
async with httpx.AsyncClient() as session:
client = HelloServiceClient(
base_url="https://api.example.com",
session=session
)

async with HelloServiceClient(
base_url="https://api.example.com",
) as client:
# Make a unary RPC call
response = await client.say_hello(HelloRequest(name="World"))
print(response.message) # "Hello, World!"
Expand All @@ -117,18 +113,14 @@ app = HelloServiceASGIApplication(MyHelloService())
### Basic Client Usage (Synchronous)

```python
import httpx
from your_service_pb2 import HelloRequest
from your_service_connect import HelloServiceClientSync

# Create sync client
def main():
with httpx.Client() as session:
client = HelloServiceClientSync(
base_url="https://api.example.com",
session=session
)

with HelloServiceClientSync(
base_url="https://api.example.com",
) as client:
# Make a unary RPC call
response = client.say_hello(HelloRequest(name="World"))
print(response.message) # "Hello, World!"
Expand Down
99 changes: 27 additions & 72 deletions conformance/test/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
import contextlib
import multiprocessing
import queue
import ssl
import sys
import time
import traceback
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Literal, TypeVar, get_args

import httpx
from _util import create_standard_streams
from gen.connectrpc.conformance.v1.client_compat_pb2 import (
ClientCompatRequest,
Expand Down Expand Up @@ -40,9 +37,8 @@
UnimplementedRequest,
)
from google.protobuf.message import Message
from pyqwest import HTTPTransport, SyncHTTPTransport
from pyqwest import Client, HTTPTransport, SyncClient, SyncHTTPTransport
from pyqwest import HTTPVersion as PyQwestHTTPVersion
from pyqwest.httpx import AsyncPyqwestTransport, PyqwestTransport

from connectrpc.client import ResponseMetadata
from connectrpc.code import Code
Expand Down Expand Up @@ -118,41 +114,8 @@ def _unpack_request(message: Any, request: T) -> T:
return request


async def httpx_client_kwargs(test_request: ClientCompatRequest) -> dict:
kwargs = {}
match test_request.http_version:
case HTTPVersion.HTTP_VERSION_1:
kwargs["http1"] = True
kwargs["http2"] = False
case HTTPVersion.HTTP_VERSION_2:
kwargs["http1"] = False
kwargs["http2"] = True
if test_request.server_tls_cert:
ctx = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH,
cadata=test_request.server_tls_cert.decode(),
)
if test_request.HasField("client_tls_creds"):

def load_certs() -> None:
with (
NamedTemporaryFile() as cert_file,
NamedTemporaryFile() as key_file,
):
cert_file.write(test_request.client_tls_creds.cert)
cert_file.flush()
key_file.write(test_request.client_tls_creds.key)
key_file.flush()
ctx.load_cert_chain(certfile=cert_file.name, keyfile=key_file.name)

await asyncio.to_thread(load_certs)
kwargs["verify"] = ctx

return kwargs


def pyqwest_client_kwargs(test_request: ClientCompatRequest) -> dict:
kwargs: dict = {"enable_gzip": True, "enable_brotli": True, "enable_zstd": True}
kwargs: dict = {}
match test_request.http_version:
case HTTPVersion.HTTP_VERSION_1:
kwargs["http_version"] = PyQwestHTTPVersion.HTTP1
Expand All @@ -169,28 +132,26 @@ def pyqwest_client_kwargs(test_request: ClientCompatRequest) -> dict:

@contextlib.asynccontextmanager
async def client_sync(
test_request: ClientCompatRequest, client_type: Client
test_request: ClientCompatRequest,
) -> AsyncIterator[ConformanceServiceClientSync]:
read_max_bytes = None
if test_request.message_receive_limit:
read_max_bytes = test_request.message_receive_limit
scheme = "https" if test_request.server_tls_cert else "http"
args = pyqwest_client_kwargs(test_request)

cleanup = contextlib.ExitStack()
match client_type:
case "httpx":
args = await httpx_client_kwargs(test_request)
session = cleanup.enter_context(httpx.Client(**args))
case "pyqwest":
args = pyqwest_client_kwargs(test_request)
http_transport = cleanup.enter_context(SyncHTTPTransport(**args))
transport = cleanup.enter_context(PyqwestTransport(http_transport))
session = cleanup.enter_context(httpx.Client(transport=transport))
if args:
transport = cleanup.enter_context(SyncHTTPTransport(**args))
http_client = SyncClient(transport)
else:
http_client = None

with (
cleanup,
ConformanceServiceClientSync(
f"{scheme}://{test_request.host}:{test_request.port}",
session=session,
http_client=http_client,
send_compression=_convert_compression(test_request.compression),
proto_json=test_request.codec == Codec.CODEC_JSON,
grpc=test_request.protocol == Protocol.PROTOCOL_GRPC,
Expand All @@ -202,32 +163,29 @@ async def client_sync(

@contextlib.asynccontextmanager
async def client_async(
test_request: ClientCompatRequest, client_type: Client
test_request: ClientCompatRequest,
) -> AsyncIterator[ConformanceServiceClient]:
read_max_bytes = None
if test_request.message_receive_limit:
read_max_bytes = test_request.message_receive_limit
scheme = "https" if test_request.server_tls_cert else "http"
args = pyqwest_client_kwargs(test_request)

cleanup = contextlib.AsyncExitStack()
match client_type:
case "httpx":
args = await httpx_client_kwargs(test_request)
session = await cleanup.enter_async_context(httpx.AsyncClient(**args))
case "pyqwest":
args = pyqwest_client_kwargs(test_request)
http_transport = await cleanup.enter_async_context(HTTPTransport(**args))
transport = await cleanup.enter_async_context(
AsyncPyqwestTransport(http_transport)
)
session = await cleanup.enter_async_context(
httpx.AsyncClient(transport=transport)
)
if args:
transport = HTTPTransport(**args)
# Type parameter for enter_async_context requires coroutine even though
# implementation doesn't. We can directly push aexit to work around it.
cleanup.push_async_exit(transport.__aexit__)
http_client = Client(transport)
else:
http_client = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added coverage of a client without an explicit transport which could have been good before too


async with (
cleanup,
ConformanceServiceClient(
f"{scheme}://{test_request.host}:{test_request.port}",
session=session,
http_client=http_client,
send_compression=_convert_compression(test_request.compression),
proto_json=test_request.codec == Codec.CODEC_JSON,
grpc=test_request.protocol == Protocol.PROTOCOL_GRPC,
Expand All @@ -238,7 +196,7 @@ async def client_async(


async def _run_test(
mode: Mode, test_request: ClientCompatRequest, client_type: Client
mode: Mode, test_request: ClientCompatRequest
) -> ClientCompatResponse:
test_response = ClientCompatResponse()
test_response.test_name = test_request.test_name
Expand All @@ -260,7 +218,7 @@ async def _run_test(
request_closed = asyncio.Event()
match mode:
case "sync":
async with client_sync(test_request, client_type) as client:
async with client_sync(test_request) as client:
match test_request.method:
case "BidiStream":
request_queue = queue.Queue()
Expand Down Expand Up @@ -468,7 +426,7 @@ def send_unary_request_sync(
task.cancel()
await task
case "async":
async with client_async(test_request, client_type) as client:
async with client_async(test_request) as client:
match test_request.method:
case "BidiStream":
request_queue = asyncio.Queue()
Expand Down Expand Up @@ -691,20 +649,17 @@ async def send_unary_request(


Mode = Literal["sync", "async"]
Client = Literal["httpx", "pyqwest"]


class Args(argparse.Namespace):
mode: Mode
client: Client
parallel: int


async def main() -> None:
parser = argparse.ArgumentParser(description="Conformance client")
parser.add_argument("--mode", choices=get_args(Mode))
parser.add_argument("--parallel", type=int, default=multiprocessing.cpu_count() * 4)
parser.add_argument("--client", choices=get_args(Client))
args = parser.parse_args(namespace=Args())

stdin, stdout = await create_standard_streams()
Expand All @@ -724,7 +679,7 @@ async def main() -> None:

async def task(request: ClientCompatRequest) -> None:
async with sema:
response = await _run_test(args.mode, request, args.client)
response = await _run_test(args.mode, request)

response_buf = response.SerializeToString()
size_buf = len(response_buf).to_bytes(4, byteorder="big")
Expand Down
44 changes: 4 additions & 40 deletions conformance/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,12 @@
"Client Cancellation/**",
]

_httpx_opts = [
# Trailers not supported
"--skip",
"**/Protocol:PROTOCOL_GRPC/**",
"--skip",
"gRPC Trailers/**",
"--skip",
"gRPC Unexpected Responses/**",
"--skip",
"gRPC Empty Responses/**",
"--skip",
"gRPC Proto Sub-Format Responses/**",
# Bidirectional streaming not supported
"--skip",
"**/full-duplex/**",
# Cancellation delivery isn't reliable
"--known-flaky",
"Client Cancellation/**",
"--known-flaky",
"Timeouts/**",
]


@pytest.mark.parametrize("client", ["httpx", "pyqwest"])
def test_client_sync(client: str) -> None:
def test_client_sync() -> None:
args = maybe_patch_args_with_debug(
[sys.executable, _client_py_path, "--mode", "sync", "--client", client]
[sys.executable, _client_py_path, "--mode", "sync"]
)

opts = []
match client:
case "httpx":
opts = _httpx_opts

result = subprocess.run(
[
"go",
Expand All @@ -61,7 +33,6 @@ def test_client_sync(client: str) -> None:
_config_path,
"--mode",
"client",
*opts,
*_skipped_tests_sync,
"--",
*args,
Expand All @@ -74,17 +45,11 @@ def test_client_sync(client: str) -> None:
pytest.fail(f"\n{result.stdout}\n{result.stderr}")


@pytest.mark.parametrize("client", ["httpx", "pyqwest"])
def test_client_async(client: str) -> None:
def test_client_async() -> None:
args = maybe_patch_args_with_debug(
[sys.executable, _client_py_path, "--mode", "async", "--client", client]
[sys.executable, _client_py_path, "--mode", "async"]
)

opts = []
match client:
case "httpx":
opts = _httpx_opts

result = subprocess.run(
[
"go",
Expand All @@ -94,7 +59,6 @@ def test_client_async(client: str) -> None:
_config_path,
"--mode",
"client",
*opts,
"--",
*args,
],
Expand Down
Loading