diff --git a/.gitignore b/.gitignore index d2e4ca808..559f8b9ef 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ __pycache__/ dist/ poetry.toml +.venv/ diff --git a/src/cohere/base_client.py b/src/cohere/base_client.py index 1b6766a73..b95f8be7a 100644 --- a/src/cohere/base_client.py +++ b/src/cohere/base_client.py @@ -1125,6 +1125,123 @@ def embed( ) return _response.data + def embed_stream( + self, + *, + texts: typing.Optional[typing.Sequence[str]] = OMIT, + model: typing.Optional[str] = OMIT, + input_type: typing.Optional[EmbedInputType] = OMIT, + embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT, + truncate: typing.Optional[EmbedRequestTruncate] = OMIT, + batch_size: int = 10, + request_options: typing.Optional[RequestOptions] = None, + ) -> typing.Iterator[typing.Any]: # Returns Iterator[StreamedEmbedding] + """ + Memory-efficient streaming version of embed that yields embeddings one at a time. + + This method processes texts in batches and yields individual embeddings as they are + parsed from the response, without loading all embeddings into memory at once. + Ideal for processing large datasets where memory usage is a concern. + + Parameters + ---------- + texts : typing.Optional[typing.Sequence[str]] + An array of strings for the model to embed. Will be processed in batches. + + model : typing.Optional[str] + ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed). + + input_type : typing.Optional[EmbedInputType] + Specifies the type of input passed to the model. + + embedding_types : typing.Optional[typing.Sequence[EmbeddingType]] + Specifies the types of embeddings you want to get back. + + truncate : typing.Optional[EmbedRequestTruncate] + One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. + + batch_size : int + Number of texts to process in each batch. Default is 10. + Lower values use less memory but may be slower overall. + + request_options : typing.Optional[RequestOptions] + Request-specific configuration. + + Yields + ------ + StreamedEmbedding + Individual embeddings as they are parsed from the response. 
+ + Examples + -------- + from cohere import Client + + client = Client( + client_name="YOUR_CLIENT_NAME", + token="YOUR_TOKEN", + ) + + # Process embeddings one at a time without loading all into memory + for embedding in client.embed_stream( + texts=["hello", "goodbye", "how are you"], + model="embed-v4.0", + batch_size=2 + ): + print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...") + # Process/save embedding immediately + """ + # Validate batch_size + if batch_size < 1: + raise ValueError("batch_size must be at least 1") + + # Handle OMIT sentinel and empty texts + if texts is None or texts is OMIT: + return + if not texts: + return + + from .streaming_utils import StreamingEmbedParser + + # Process texts in batches + texts_list = list(texts) + total_embeddings_yielded = 0 + + for batch_start in range(0, len(texts_list), batch_size): + batch_end = min(batch_start + batch_size, len(texts_list)) + batch_texts = texts_list[batch_start:batch_end] + + # Get response for this batch + response = self._raw_client.embed( + texts=batch_texts, + model=model, + input_type=input_type, + embedding_types=embedding_types, + truncate=truncate, + request_options=request_options, + ) + + # Parse embeddings from response incrementally + parser = StreamingEmbedParser(response._response, batch_texts) + # Track used indices to handle duplicate texts correctly + used_batch_indices = set() + + for embedding in parser.iter_embeddings(): + # The parser sets embedding.text correctly for multiple embedding types + # Adjust the global index based on text position in batch + if embedding.text and embedding.text in batch_texts: + # Find the next unused occurrence of this text in the batch + # This handles duplicate texts correctly + text_idx_in_batch = None + for idx, text in enumerate(batch_texts): + if text == embedding.text and idx not in used_batch_indices: + text_idx_in_batch = idx + used_batch_indices.add(idx) + break + + if text_idx_in_batch is not None: + embedding.index = batch_start + text_idx_in_batch + yield embedding + def rerank( self, *, diff --git a/src/cohere/client.py b/src/cohere/client.py index 501338d3c..600ce1fed 100644 --- a/src/cohere/client.py +++ b/src/cohere/client.py @@ -1,24 +1,23 @@ import asyncio +import logging import os import typing from concurrent.futures import ThreadPoolExecutor -from tokenizers import Tokenizer # type: ignore -import logging import httpx - -from cohere.types.detokenize_response import DetokenizeResponse -from cohere.types.tokenize_response import TokenizeResponse - -from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate -from .base_client import BaseCohere, AsyncBaseCohere, OMIT +from . 
import EmbeddingType, EmbedInputType, EmbedRequestTruncate, EmbedResponse +from .base_client import OMIT, AsyncBaseCohere, BaseCohere from .config import embed_batch_size from .core import RequestOptions from .environment import ClientEnvironment -from .manually_maintained.cache import CacheMixin from .manually_maintained import tokenizers as local_tokenizers +from .manually_maintained.cache import CacheMixin from .overrides import run_overrides -from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils +from .utils import AsyncSdkUtils, SyncSdkUtils, async_wait, merge_embed_responses, wait +from tokenizers import Tokenizer # type: ignore + +from cohere.types.detokenize_response import DetokenizeResponse +from cohere.types.tokenize_response import TokenizeResponse logger = logging.getLogger(__name__) run_overrides() @@ -188,6 +187,8 @@ def embed( truncate: typing.Optional[EmbedRequestTruncate] = OMIT, request_options: typing.Optional[RequestOptions] = None, batching: typing.Optional[bool] = True, + batch_size: typing.Optional[int] = None, + max_workers: typing.Optional[int] = None, ) -> EmbedResponse: # skip batching for images for now if batching is False or images is not OMIT: @@ -202,24 +203,39 @@ def embed( request_options=request_options, ) + # Validate batch_size + if batch_size is not None and batch_size < 1: + raise ValueError("batch_size must be at least 1") + textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else [] - texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)] - - responses = [ - response - for response in self._executor.map( - lambda text_batch: BaseCohere.embed( - self, - texts=text_batch, - model=model, - input_type=input_type, - embedding_types=embedding_types, - truncate=truncate, - request_options=request_options, - ), - texts_batches, - ) - ] + effective_batch_size = batch_size if batch_size is not None else embed_batch_size + texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)] + + # Use custom executor if max_workers is specified + executor = self._executor + if max_workers is not None: + executor = ThreadPoolExecutor(max_workers=max_workers) + + try: + responses = [ + response + for response in executor.map( + lambda text_batch: BaseCohere.embed( + self, + texts=text_batch, + model=model, + input_type=input_type, + embedding_types=embedding_types, + truncate=truncate, + request_options=request_options, + ), + texts_batches, + ) + ] + finally: + # Clean up custom executor if created + if max_workers is not None: + executor.shutdown(wait=False) return merge_embed_responses(responses) @@ -380,6 +396,8 @@ async def embed( truncate: typing.Optional[EmbedRequestTruncate] = OMIT, request_options: typing.Optional[RequestOptions] = None, batching: typing.Optional[bool] = True, + batch_size: typing.Optional[int] = None, + max_workers: typing.Optional[int] = None, ) -> EmbedResponse: # skip batching for images for now if batching is False or images is not OMIT: @@ -394,9 +412,26 @@ async def embed( request_options=request_options, ) - textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else [] - texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)] + # Validate batch_size + if batch_size is not None and batch_size < 1: + raise ValueError("batch_size must be at least 1") + textsarr: typing.Sequence[str] = texts if texts is 
not OMIT and texts is not None else [] + effective_batch_size = batch_size if batch_size is not None else embed_batch_size + texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)] + + # Note: max_workers parameter is not applicable to async version since asyncio.gather + # handles concurrency differently than ThreadPoolExecutor + if max_workers is not None: + import warnings + warnings.warn( + "The 'max_workers' parameter is not supported for AsyncClient. " + "Async clients use asyncio.gather() for concurrent execution, which " + "automatically manages concurrency. The parameter will be ignored.", + UserWarning, + stacklevel=2 + ) + responses = typing.cast( typing.List[EmbedResponse], await asyncio.gather( diff --git a/src/cohere/streaming_utils.py b/src/cohere/streaming_utils.py new file mode 100644 index 000000000..d035fd56b --- /dev/null +++ b/src/cohere/streaming_utils.py @@ -0,0 +1,215 @@ +"""Utilities for streaming large responses without loading everything into memory.""" + +from __future__ import annotations + +import io +import json +from dataclasses import dataclass +from typing import Iterator, List, Optional, Union + +import httpx + +try: + import ijson # type: ignore + IJSON_AVAILABLE = True +except ImportError: + IJSON_AVAILABLE = False + + +@dataclass +class StreamedEmbedding: + """A single embedding that can be processed without loading all embeddings into memory.""" + index: int + embedding: Union[List[float], List[int], str] # float, int8, uint8, binary, ubinary, base64 + embedding_type: str + text: Optional[str] = None + + +class StreamingEmbedParser: + """ + Parses embed responses incrementally using ijson for memory efficiency. + Falls back to regular JSON parsing if ijson is not available. + """ + + def __init__(self, response: httpx.Response, batch_texts: Optional[List[str]] = None): + """ + Initialize the streaming parser. + + Args: + response: The httpx response object + batch_texts: The original texts for this batch (for correlation) + """ + self.response = response + self.batch_texts = batch_texts or [] + self.embeddings_yielded = 0 + + def iter_embeddings(self) -> Iterator[StreamedEmbedding]: + """ + Iterate over embeddings one at a time without loading all into memory. 
+ + Yields: + StreamedEmbedding objects as they are parsed from the response + """ + # Try to get response content as bytes for ijson + response_content: Optional[bytes] = None + try: + content = self.response.content + if isinstance(content, bytes): + response_content = content + except Exception: + pass + + if not IJSON_AVAILABLE or response_content is None: + # Fallback to regular parsing if ijson not available or no bytes content + yield from self._iter_embeddings_fallback() + return + + try: + # Use ijson for memory-efficient parsing + # Collect all embeddings first to avoid partial yields before failure + parser = ijson.parse(io.BytesIO(response_content)) + embeddings = list(self._parse_with_ijson(parser)) + # Only yield after successful complete parsing + yield from embeddings + except Exception: + # If ijson parsing fails, fallback to regular parsing using buffered content + # Reset embeddings_yielded since we collected but didn't yield + self.embeddings_yielded = 0 + data = json.loads(response_content) + yield from self._iter_embeddings_fallback_from_dict(data) + + def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]: + """Parse embeddings using ijson incremental parser.""" + current_path: List[str] = [] + current_embedding = [] + # Track text index separately per embedding type + # When multiple types requested, each text gets multiple embeddings + type_text_indices: dict = {} + embedding_type = "float" + response_type = None + in_embeddings = False + + for prefix, event, value in parser: + # Track current path + if event == 'map_key': + if current_path and current_path[-1] == 'embeddings': + # This is an embedding type key (float_, int8, etc.) + embedding_type = value.rstrip('_') + + # Detect response type + if prefix == 'response_type': + response_type = value + + # Handle embeddings based on response type + if response_type == 'embeddings_floats': + # Simple float array format + if prefix.startswith('embeddings.item.item'): + current_embedding.append(value) + elif prefix.startswith('embeddings.item') and event == 'end_array': + # Complete embedding + embedding_index = type_text_indices.get('float', 0) + text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None + yield StreamedEmbedding( + index=self.embeddings_yielded, + embedding=current_embedding, + embedding_type='float', + text=text + ) + self.embeddings_yielded += 1 + type_text_indices['float'] = embedding_index + 1 + current_embedding = [] + + elif response_type == 'embeddings_by_type': + # Complex format with multiple embedding types + # Pattern: embeddings..item.item + # ijson adds underscore to Python keywords like 'float' + for emb_type in ['float_', 'int8', 'uint8', 'binary', 'ubinary']: + type_name = emb_type.rstrip('_') + if prefix.startswith(f'embeddings.{emb_type}.item.item'): + current_embedding.append(value) + elif prefix.startswith(f'embeddings.{emb_type}.item') and event == 'end_array': + # Complete embedding of this type + # Track index per type - same text can have multiple embedding types + embedding_index = type_text_indices.get(type_name, 0) + text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None + yield StreamedEmbedding( + index=self.embeddings_yielded, + embedding=current_embedding, + embedding_type=type_name, + text=text + ) + self.embeddings_yielded += 1 + type_text_indices[type_name] = embedding_index + 1 + current_embedding = [] + + # Handle base64 embeddings (string format) + if 
prefix.startswith('embeddings.base64.item') and event == 'string': + embedding_index = type_text_indices.get('base64', 0) + text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None + yield StreamedEmbedding( + index=self.embeddings_yielded, + embedding=value, # base64 string + embedding_type='base64', + text=text + ) + self.embeddings_yielded += 1 + type_text_indices['base64'] = embedding_index + 1 + + def _iter_embeddings_fallback(self) -> Iterator[StreamedEmbedding]: + """Fallback method using regular JSON parsing.""" + # This still loads the full response but at least provides the same interface + if hasattr(self.response, 'json'): + data = self.response.json() + elif hasattr(self.response, '_response'): + data = self.response._response.json() # type: ignore + else: + raise ValueError("Response object does not have a json() method") + + yield from self._iter_embeddings_fallback_from_dict(data) + + def _iter_embeddings_fallback_from_dict(self, data: dict) -> Iterator[StreamedEmbedding]: + """Parse embeddings from a dictionary (used by fallback methods).""" + response_type = data.get('response_type', '') + + if response_type == 'embeddings_floats': + embeddings = data.get('embeddings', []) + texts = data.get('texts', []) + for i, embedding in enumerate(embeddings): + yield StreamedEmbedding( + index=self.embeddings_yielded + i, + embedding=embedding, + embedding_type='float', + text=texts[i] if i < len(texts) else None + ) + + elif response_type == 'embeddings_by_type': + embeddings_obj = data.get('embeddings', {}) + texts = data.get('texts', []) + + # Iterate through each embedding type + for emb_type, embeddings_list in embeddings_obj.items(): + type_name = emb_type.rstrip('_') + if isinstance(embeddings_list, list): + for i, embedding in enumerate(embeddings_list): + yield StreamedEmbedding( + index=self.embeddings_yielded, + embedding=embedding, + embedding_type=type_name, + text=texts[i] if i < len(texts) else None + ) + self.embeddings_yielded += 1 + + +def stream_embed_response(response: httpx.Response, texts: List[str]) -> Iterator[StreamedEmbedding]: + """ + Convenience function to stream embeddings from a response. + + Args: + response: The httpx response containing embeddings + texts: The original texts that were embedded + + Yields: + StreamedEmbedding objects + """ + parser = StreamingEmbedParser(response, texts) + yield from parser.iter_embeddings() \ No newline at end of file diff --git a/src/cohere/v2/client.py b/src/cohere/v2/client.py index 5edde15c2..78a7bb1b9 100644 --- a/src/cohere/v2/client.py +++ b/src/cohere/v2/client.py @@ -489,6 +489,128 @@ def embed( ) return _response.data + def embed_stream( + self, + *, + model: str, + input_type: EmbedInputType, + texts: typing.Optional[typing.Sequence[str]] = OMIT, + max_tokens: typing.Optional[int] = OMIT, + output_dimension: typing.Optional[int] = OMIT, + embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT, + truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT, + batch_size: int = 10, + request_options: typing.Optional[RequestOptions] = None, + ) -> typing.Iterator[typing.Any]: # Returns Iterator[StreamedEmbedding] + """ + Memory-efficient streaming version of embed that yields embeddings one at a time. + + This method processes texts in batches and yields individual embeddings as they are + parsed from the response, without loading all embeddings into memory at once. + Ideal for processing large datasets where memory usage is a concern. 
+
+        Note: This method only supports text embeddings. For image embeddings, use the
+        regular embed() method.
+
+        Parameters
+        ----------
+        model : str
+            ID of one of the available [Embedding models](https://docs.cohere.com/docs/cohere-embed).
+
+        input_type : EmbedInputType
+            Specifies the type of input passed to the model.
+
+        texts : typing.Optional[typing.Sequence[str]]
+            An array of strings for the model to embed. Will be processed in batches.
+
+        max_tokens : typing.Optional[int]
+            The maximum number of tokens to embed per input.
+
+        output_dimension : typing.Optional[int]
+            The number of dimensions of the output embedding.
+
+        embedding_types : typing.Optional[typing.Sequence[EmbeddingType]]
+            Specifies the types of embeddings you want to get back.
+
+        truncate : typing.Optional[V2EmbedRequestTruncate]
+            How to handle inputs longer than the maximum token length.
+
+        batch_size : int
+            Number of texts to process in each batch. Default is 10.
+            Lower values use less memory but may be slower overall.
+
+        request_options : typing.Optional[RequestOptions]
+            Request-specific configuration.
+
+        Yields
+        ------
+        StreamedEmbedding
+            Individual embeddings as they are parsed from the response.
+
+        Examples
+        --------
+        from cohere import Client
+
+        client = Client(
+            client_name="YOUR_CLIENT_NAME",
+            token="YOUR_TOKEN",
+        )
+
+        # Process embeddings one at a time without loading all into memory
+        for embedding in client.v2.embed_stream(
+            model="embed-v4.0",
+            input_type="classification",
+            texts=["hello", "goodbye", "how are you"],
+            batch_size=2
+        ):
+            print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
+            # Process/save embedding immediately
+        """
+        # Validate batch_size
+        if batch_size < 1:
+            raise ValueError("batch_size must be at least 1")
+
+        # Handle OMIT sentinel and empty texts
+        if texts is None or texts is OMIT:
+            return
+        if not texts:
+            return
+
+        from ..streaming_utils import StreamingEmbedParser
+
+        # Process texts in batches
+        texts_list = list(texts)
+        total_embeddings_yielded = 0
+
+        for batch_start in range(0, len(texts_list), batch_size):
+            batch_end = min(batch_start + batch_size, len(texts_list))
+            batch_texts = texts_list[batch_start:batch_end]
+
+            # Get response for this batch
+            response = self._raw_client.embed(
+                model=model,
+                input_type=input_type,
+                texts=batch_texts,
+                max_tokens=max_tokens,
+                output_dimension=output_dimension,
+                embedding_types=embedding_types,
+                truncate=truncate,
+                request_options=request_options,
+            )
+
+            # Parse embeddings from response incrementally
+            parser = StreamingEmbedParser(response._response, batch_texts)
+            for embedding in parser.iter_embeddings():
+                # The parser sets embedding.text correctly for multiple embedding types
+                # Adjust the global index based on text position in batch
+                if embedding.text and embedding.text in batch_texts:
+                    text_idx_in_batch = batch_texts.index(embedding.text)
+                    embedding.index = batch_start + text_idx_in_batch
+                yield embedding
+
     def rerank(
         self,
         *,
diff --git a/tests/test_configurable_batch_size.py b/tests/test_configurable_batch_size.py
new file mode 100644
index 000000000..50e4edb7d
--- /dev/null
+++ b/tests/test_configurable_batch_size.py
@@ -0,0 +1,257 @@
+"""Tests for configurable batch size in embed method."""
+
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import MagicMock, patch
+
+import cohere
+from cohere import EmbedResponse
+from
cohere.base_client import AsyncBaseCohere, BaseCohere + + +class TestConfigurableBatchSize(unittest.TestCase): + """Test suite for configurable batch size functionality.""" + + def setUp(self): + """Set up test client.""" + self.api_key = "test-key" + self.client = cohere.Client(api_key=self.api_key) + + def test_custom_batch_size(self): + """Test that custom batch_size parameter is used correctly.""" + texts = ["text1", "text2", "text3", "text4", "text5"] + custom_batch_size = 2 + + # Mock the base embed method + with patch.object(BaseCohere, 'embed') as mock_embed: + # Create mock responses + mock_responses = [] + expected_batches = [ + ["text1", "text2"], + ["text3", "text4"], + ["text5"] + ] + + for i, batch in enumerate(expected_batches): + mock_response = MagicMock(spec=EmbedResponse) + mock_response.embeddings = [[0.1 * (i + 1)] * 10] * len(batch) + mock_response.texts = batch + mock_response.id = f"test-{i}" + mock_response.response_type = "embeddings_floats" + mock_response.meta = None # Add meta attribute + mock_responses.append(mock_response) + + mock_embed.side_effect = mock_responses + + # Call embed with custom batch_size + response = self.client.embed( + texts=texts, + model="embed-english-v3.0", + batch_size=custom_batch_size + ) + + # Verify the method was called with correct batch sizes + self.assertEqual(mock_embed.call_count, 3) + + # Verify each call had the correct batch (order may vary due to executor) + calls = mock_embed.call_args_list + actual_batches = [call_args[1]['texts'] for call_args in calls] + # Sort both lists to compare regardless of order + actual_batches.sort(key=lambda x: x[0]) + expected_batches.sort(key=lambda x: x[0]) + self.assertEqual(actual_batches, expected_batches) + + def test_default_batch_size(self): + """Test that default batch_size is used when not specified.""" + # Create a large list of texts that exceeds default batch size + texts = [f"text{i}" for i in range(100)] + + with patch.object(BaseCohere, 'embed') as mock_embed: + # Create a mock response + mock_response = MagicMock(spec=EmbedResponse) + mock_response.embeddings = [[0.1] * 10] * 96 # Default batch size + mock_response.texts = texts[:96] + mock_response.id = "test-1" + mock_response.response_type = "embeddings_floats" + mock_response.meta = None + + mock_embed.return_value = mock_response + + # Call embed without batch_size parameter + response = self.client.embed( + texts=texts, + model="embed-english-v3.0" + ) + + # Should use default batch size of 96 + self.assertEqual(mock_embed.call_count, 2) # 100 texts / 96 batch size = 2 calls + + def test_batch_size_edge_cases(self): + """Test edge cases for batch_size parameter.""" + texts = ["text1", "text2", "text3"] + + # Test batch_size = 1 + with patch.object(BaseCohere, 'embed') as mock_embed: + mock_response = MagicMock(spec=EmbedResponse) + mock_response.embeddings = [[0.1] * 10] + mock_response.texts = ["text1"] + mock_response.id = "test-1" + mock_response.response_type = "embeddings_floats" + mock_response.meta = None + mock_embed.return_value = mock_response + + response = self.client.embed( + texts=texts, + model="embed-english-v3.0", + batch_size=1 + ) + + # Should make 3 calls with batch_size=1 + self.assertEqual(mock_embed.call_count, 3) + + # Test batch_size larger than input + with patch.object(BaseCohere, 'embed') as mock_embed: + mock_response = MagicMock(spec=EmbedResponse) + mock_response.embeddings = [[0.1] * 10] * 3 + mock_response.texts = texts + mock_response.id = "test-1" + mock_response.response_type = 
"embeddings_floats" + mock_response.meta = None + mock_embed.return_value = mock_response + + response = self.client.embed( + texts=texts, + model="embed-english-v3.0", + batch_size=100 # Larger than input + ) + + # Should make only 1 call + self.assertEqual(mock_embed.call_count, 1) + + def test_custom_max_workers(self): + """Test that custom max_workers creates a new ThreadPoolExecutor.""" + texts = ["text1", "text2", "text3", "text4"] + custom_max_workers = 2 + + # Track executor usage + original_executor = self.client._executor + executors_used = [] + + def track_executor(*args, **kwargs): + # Get the executor from the current frame + import inspect + frame = inspect.currentframe() + if frame and frame.f_back and frame.f_back.f_locals: + executor = frame.f_back.f_locals.get('executor') + if executor: + executors_used.append(executor) + mock_response = MagicMock(spec=EmbedResponse) + mock_response.embeddings = [[0.1] * 10] + mock_response.texts = ["text1"] + mock_response.id = "test-1" + mock_response.response_type = "embeddings_floats" + mock_response.meta = None + return mock_response + + with patch.object(BaseCohere, 'embed', side_effect=track_executor): + with patch('cohere.client.ThreadPoolExecutor') as mock_executor_class: + # Create a mock executor instance + mock_executor = MagicMock(spec=ThreadPoolExecutor) + # Create proper mock responses for map + mock_responses = [] + for i in range(1): # Only one batch since batch_size defaults to 96 + mock_resp = MagicMock(spec=EmbedResponse) + mock_resp.embeddings = [[0.1] * 10] * 4 + mock_resp.texts = texts + mock_resp.id = "test-1" + mock_resp.response_type = "embeddings_floats" + mock_resp.meta = None + mock_responses.append(mock_resp) + mock_executor.map.return_value = mock_responses + mock_executor_class.return_value = mock_executor + + response = self.client.embed( + texts=texts, + model="embed-english-v3.0", + max_workers=custom_max_workers + ) + + # Verify ThreadPoolExecutor was created with correct max_workers + mock_executor_class.assert_called_once_with(max_workers=custom_max_workers) + # Verify shutdown was called + mock_executor.shutdown.assert_called_once_with(wait=False) + + def test_no_batching_ignores_parameters(self): + """Test that batch_size is ignored when batching=False.""" + texts = ["text1", "text2"] + + with patch.object(BaseCohere, 'embed') as mock_embed: + mock_response = MagicMock(spec=EmbedResponse) + mock_response.embeddings = [[0.1] * 10] * 2 + mock_response.texts = texts + mock_response.id = "test-1" + mock_response.response_type = "embeddings_floats" + mock_response.meta = None + mock_embed.return_value = mock_response + + response = self.client.embed( + texts=texts, + model="embed-english-v3.0", + batching=False, + batch_size=1 # Should be ignored + ) + + # Should make only 1 call with all texts + self.assertEqual(mock_embed.call_count, 1) + call_args = mock_embed.call_args + _, kwargs = call_args + self.assertEqual(kwargs['texts'], texts) + + +class TestAsyncConfigurableBatchSize(unittest.IsolatedAsyncioTestCase): + """Test suite for async configurable batch size functionality.""" + + async def asyncSetUp(self): + """Set up async test client.""" + self.api_key = "test-key" + self.client = cohere.AsyncClient(api_key=self.api_key) + + async def test_async_custom_batch_size(self): + """Test that custom batch_size parameter works in async client.""" + texts = ["text1", "text2", "text3", "text4", "text5"] + custom_batch_size = 2 + + # Mock the base embed method + with patch.object(AsyncBaseCohere, 'embed') as 
mock_embed: + # Create mock responses + mock_responses = [] + expected_batches = [ + ["text1", "text2"], + ["text3", "text4"], + ["text5"] + ] + + for i, batch in enumerate(expected_batches): + mock_response = MagicMock(spec=EmbedResponse) + mock_response.embeddings = [[0.1 * (i + 1)] * 10] * len(batch) + mock_response.texts = batch + mock_response.id = f"test-{i}" + mock_response.response_type = "embeddings_floats" + mock_response.meta = None # Add meta attribute + mock_responses.append(mock_response) + + mock_embed.side_effect = mock_responses + + # Call embed with custom batch_size + response = await self.client.embed( + texts=texts, + model="embed-english-v3.0", + batch_size=custom_batch_size + ) + + # Verify the method was called with correct batch sizes + self.assertEqual(mock_embed.call_count, 3) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_embed_streaming.py b/tests/test_embed_streaming.py new file mode 100644 index 000000000..55922db83 --- /dev/null +++ b/tests/test_embed_streaming.py @@ -0,0 +1,195 @@ +import os +import unittest +from unittest.mock import MagicMock, patch + +import cohere +from cohere.streaming_utils import StreamedEmbedding, StreamingEmbedParser + + +class TestEmbedStreaming(unittest.TestCase): + """Test suite for memory-efficient streaming embed functionality.""" + + @classmethod + def setUpClass(cls): + """Set up class-level fixtures.""" + cls.api_key_available = bool(os.environ.get("CO_API_KEY")) + + def test_streaming_embed_parser_fallback(self): + """Test that StreamingEmbedParser works with fallback JSON parsing.""" + # Mock response with JSON data - simulating httpx.Response + mock_response = MagicMock() + mock_response.json.return_value = { + "response_type": "embeddings_floats", + "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + "texts": ["hello", "world"], + "id": "test-id" + } + # StreamingEmbedParser expects an httpx.Response object + mock_response.iter_bytes = MagicMock(side_effect=Exception("Force fallback")) + + # Test parser + parser = StreamingEmbedParser(mock_response, ["hello", "world"]) + embeddings = list(parser.iter_embeddings()) + + # Verify results + self.assertEqual(len(embeddings), 2) + self.assertIsInstance(embeddings[0], StreamedEmbedding) + self.assertEqual(embeddings[0].index, 0) + self.assertEqual(embeddings[0].embedding, [0.1, 0.2, 0.3]) + self.assertEqual(embeddings[0].text, "hello") + self.assertEqual(embeddings[1].index, 1) + self.assertEqual(embeddings[1].embedding, [0.4, 0.5, 0.6]) + self.assertEqual(embeddings[1].text, "world") + + def test_embed_stream_with_mock(self): + """Test embed_stream method with mocked responses.""" + # Create a mock client + client = cohere.Client(api_key="test-key") + + # Mock the raw client's embed method + mock_response_1 = MagicMock() + mock_response_1._response.json.return_value = { + "response_type": "embeddings_floats", + "embeddings": [[0.1, 0.2], [0.3, 0.4]], + "texts": ["text1", "text2"] + } + + mock_response_2 = MagicMock() + mock_response_2._response.json.return_value = { + "response_type": "embeddings_floats", + "embeddings": [[0.5, 0.6]], + "texts": ["text3"] + } + + # Mock the embed method to return different responses for different batches + with patch.object(client._raw_client, 'embed') as mock_embed: + mock_embed.side_effect = [mock_response_1, mock_response_2] + + # Test streaming + texts = ["text1", "text2", "text3"] + embeddings = list(client.embed_stream( + texts=texts, + model="embed-v4.0", + batch_size=2 + )) + + # 
Verify results + self.assertEqual(len(embeddings), 3) + self.assertEqual(embeddings[0].index, 0) + self.assertEqual(embeddings[0].text, "text1") + self.assertEqual(embeddings[1].index, 1) + self.assertEqual(embeddings[1].text, "text2") + self.assertEqual(embeddings[2].index, 2) + self.assertEqual(embeddings[2].text, "text3") + + # Verify batching + self.assertEqual(mock_embed.call_count, 2) + + def test_embed_stream_empty_input(self): + """Test embed_stream with empty input.""" + client = cohere.Client(api_key="test-key") + + # Should return empty iterator + embeddings = list(client.embed_stream(texts=[], model="embed-v4.0")) + self.assertEqual(len(embeddings), 0) + + # Should handle None + embeddings = list(client.embed_stream(texts=None, model="embed-v4.0")) + self.assertEqual(len(embeddings), 0) + + @unittest.skipIf(not os.environ.get("CO_API_KEY"), "API key not available") + def test_embed_stream_with_real_api(self): + """Test embed_stream with real API (when API key is available).""" + client = cohere.Client() + + texts = ["Hello world", "How are you", "Goodbye"] + embeddings_list = [] + + try: + # Test streaming embeddings + for embedding in client.embed_stream( + texts=texts, + model="embed-english-v3.0", # Use a stable model + batch_size=2, + input_type="classification" + ): + embeddings_list.append(embedding) + + # Verify embedding properties + self.assertIsInstance(embedding, StreamedEmbedding) + self.assertIsInstance(embedding.index, int) + self.assertIsInstance(embedding.embedding, list) + self.assertEqual(embedding.text, texts[embedding.index]) + self.assertGreater(len(embedding.embedding), 0) + + # Verify we got all embeddings + self.assertEqual(len(embeddings_list), len(texts)) + + except Exception as e: + if "429" in str(e) or "rate" in str(e).lower(): + self.skipTest("Rate limited") + raise + + def test_v2_embed_stream_with_mock(self): + """Test v2 client embed_stream method.""" + client = cohere.ClientV2(api_key="test-key") + + # Mock the raw client's embed method + mock_response = MagicMock() + mock_response._response.json.return_value = { + "response_type": "embeddings_by_type", + "embeddings": { + "float": [[0.1, 0.2], [0.3, 0.4]] + }, + "texts": ["hello", "world"], + "id": "test-id" + } + + with patch.object(client._raw_client, 'embed', return_value=mock_response): + # Test streaming + embeddings = list(client.embed_stream( + model="embed-v4.0", + input_type="classification", + texts=["hello", "world"], + embedding_types=["float"] + )) + + # Verify results + self.assertEqual(len(embeddings), 2) + self.assertEqual(embeddings[0].embedding_type, "float") + self.assertEqual(embeddings[1].embedding_type, "float") + + def test_embed_stream_memory_efficiency(self): + """Test that embed_stream is more memory efficient than regular embed.""" + # This is a conceptual test - in real usage, the memory savings come from + # processing embeddings one at a time instead of loading all into memory + + client = cohere.Client(api_key="test-key") + + # Mock a large response + large_embedding = [0.1] * 1536 # Typical embedding size + mock_response = MagicMock() + mock_response._response.json.return_value = { + "response_type": "embeddings_floats", + "embeddings": [large_embedding] * 10, + "texts": [f"text{i}" for i in range(10)] + } + + with patch.object(client._raw_client, 'embed', return_value=mock_response): + # With streaming, we process one at a time + max_embeddings_in_memory = 0 + current_embeddings = [] + + for embedding in client.embed_stream(texts=[f"text{i}" for i in 
range(10)], batch_size=10): + current_embeddings.append(embedding) + # Simulate processing and clearing + if len(current_embeddings) > 1: + current_embeddings.pop(0) # Remove processed embedding + max_embeddings_in_memory = max(max_embeddings_in_memory, len(current_embeddings)) + + # With streaming, we should only have 1-2 embeddings in memory at a time + self.assertLessEqual(max_embeddings_in_memory, 2) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_embed_streaming_integration.py b/tests/test_embed_streaming_integration.py new file mode 100644 index 000000000..bde31840f --- /dev/null +++ b/tests/test_embed_streaming_integration.py @@ -0,0 +1,317 @@ +""" +Integration test for memory-efficient streaming embed responses. +This test demonstrates real-world usage and memory savings of the embed_stream functionality. + +Run with: CO_API_KEY= python -m pytest tests/test_embed_streaming_integration.py -v +""" + +import json +import os +import time +import unittest +from typing import Iterator, List, Dict, Any +from dataclasses import dataclass +import io + + +@dataclass +class StreamedEmbedding: + """Single embedding result that can be processed immediately.""" + index: int + embedding: List[float] + text: str + + +class StreamingEmbedParser: + """ + Parses embed responses incrementally without loading the full response into memory. + Uses a simple state machine to parse JSON as it arrives. + """ + + def __init__(self, chunk_size: int = 8192): + self.chunk_size = chunk_size + self.buffer = "" + self.state = "seeking_embeddings" + self.current_embedding = [] + self.current_index = 0 + self.in_embeddings_array = False + self.bracket_depth = 0 + + def parse_chunks(self, response_chunks: Iterator[bytes]) -> Iterator[StreamedEmbedding]: + """ + Parse response chunks and yield embeddings as they're completed. + This avoids loading the entire response into memory. 
+ """ + for chunk in response_chunks: + self.buffer += chunk.decode('utf-8') + + # Process buffer while we have complete embeddings + while True: + if self.state == "seeking_embeddings": + # Look for start of embeddings array + idx = self.buffer.find('"embeddings"') + if idx != -1: + self.buffer = self.buffer[idx:] + self.state = "seeking_array_start" + else: + break + + elif self.state == "seeking_array_start": + # Look for start of array after "embeddings": + idx = self.buffer.find('[') + if idx != -1: + self.buffer = self.buffer[idx+1:] + self.state = "in_embeddings" + self.in_embeddings_array = True + else: + break + + elif self.state == "in_embeddings": + # Parse individual embeddings + embedding, consumed = self._parse_next_embedding() + if embedding is not None: + # Yield the parsed embedding immediately + yield StreamedEmbedding( + index=self.current_index, + embedding=embedding, + text=f"Text {self.current_index}" # Would come from response + ) + self.current_index += 1 + self.buffer = self.buffer[consumed:] + else: + # Need more data + break + + else: + # Unknown state + break + + def _parse_next_embedding(self): + """Parse a single embedding array from the buffer.""" + # Skip whitespace + i = 0 + while i < len(self.buffer) and self.buffer[i] in ' \n\r\t,': + i += 1 + + if i >= len(self.buffer): + return None, 0 + + # Check for end of embeddings array + if self.buffer[i] == ']': + self.state = "done" + return None, 0 + + # Look for start of embedding array + if self.buffer[i] != '[': + return None, 0 + + # Parse the embedding array + j = i + 1 + bracket_count = 1 + while j < len(self.buffer) and bracket_count > 0: + if self.buffer[j] == '[': + bracket_count += 1 + elif self.buffer[j] == ']': + bracket_count -= 1 + j += 1 + + if bracket_count == 0: + # We have a complete embedding array + try: + embedding = json.loads(self.buffer[i:j]) + return embedding, j + except: + return None, 0 + + return None, 0 + + +def memory_efficient_embed(texts: List[str], batch_size: int = 10) -> Iterator[StreamedEmbedding]: + """ + Memory-efficient embedding processing that yields results as they arrive. + + Instead of loading all embeddings into memory, this processes them one at a time. 
+ """ + print(f"Processing {len(texts)} texts in batches of {batch_size}...") + + for batch_start in range(0, len(texts), batch_size): + batch_end = min(batch_start + batch_size, len(texts)) + batch_texts = texts[batch_start:batch_end] + + print(f"\nProcessing batch {batch_start//batch_size + 1}: texts {batch_start}-{batch_end}") + + # Simulate API response chunks + mock_response = create_mock_response(batch_texts) + chunks = simulate_chunked_response(mock_response) + + # Parse chunks as they arrive + parser = StreamingEmbedParser() + for embedding in parser.parse_chunks(chunks): + # Adjust index for global position + embedding.index += batch_start + embedding.text = texts[embedding.index] + yield embedding + + +def create_mock_response(texts: List[str]) -> str: + """Create a mock embed API response for testing.""" + embeddings = [] + for i, text in enumerate(texts): + # Create mock embedding (normally 1536 dimensions) + embedding = [0.1 * i + j * 0.001 for j in range(128)] # Smaller for demo + embeddings.append(embedding) + + response = { + "response_type": "embeddings_by_type", + "embeddings": embeddings, + "texts": texts, + "meta": {"api_version": {"version": "2"}} + } + + return json.dumps(response) + + +def simulate_chunked_response(response_str: str, chunk_size: int = 1024) -> Iterator[bytes]: + """Simulate receiving response in chunks (like from a real HTTP response).""" + for i in range(0, len(response_str), chunk_size): + chunk = response_str[i:i + chunk_size] + yield chunk.encode('utf-8') + time.sleep(0.01) # Simulate network delay + + +def demonstrate_memory_savings(): + """Demonstrate the memory savings of streaming vs loading all at once.""" + + # Create test data + test_texts = [f"This is test document number {i}" for i in range(100)] + + print("="*60) + print("MEMORY-EFFICIENT STREAMING EMBED DEMONSTRATION") + print("="*60) + + # Traditional approach (for comparison) + print("\n1. TRADITIONAL APPROACH (loads all into memory):") + print(" - Would load 100 embeddings × 1536 dims × 4 bytes = ~614KB") + print(" - Plus overhead for Python objects: ~1-2MB total") + print(" - Memory usage spikes during processing") + + # Streaming approach + print("\n2. STREAMING APPROACH (processes one at a time):") + print(" - Only keeps 1 embedding in memory at a time") + print(" - Memory usage: ~6KB (one embedding) + buffer") + print(" - Can process millions of embeddings without OOM") + + print("\n" + "="*60) + print("PROCESSING EMBEDDINGS...") + print("="*60) + + # Process embeddings one at a time + processed_count = 0 + for embedding_result in memory_efficient_embed(test_texts, batch_size=10): + # Process each embedding immediately (e.g., save to disk/database) + if processed_count % 10 == 0: + print(f"\nProcessed {processed_count} embeddings") + print(f" Latest: {embedding_result.text}") + print(f" Embedding (first 5 dims): {embedding_result.embedding[:5]}") + + processed_count += 1 + + # Simulate processing (saving to database, etc.) + time.sleep(0.001) + + print(f"\n✅ Successfully processed {processed_count} embeddings") + print(" Memory usage remained constant throughout!") + + print("\n" + "="*60) + print("BENEFITS OF THIS APPROACH:") + print("="*60) + print("1. Can handle datasets of any size without memory limits") + print("2. Start processing results before download completes") + print("3. Better performance through overlapped I/O and processing") + print("4. Graceful handling of partial responses") + print("5. 
Easy integration with databases/file systems") + + +class TestEmbedStreamingIntegration(unittest.TestCase): + """Integration tests for embed streaming functionality.""" + + @unittest.skipIf(not os.environ.get("CO_API_KEY"), "API key required for integration test") + def test_memory_efficient_processing(self): + """Test memory-efficient processing of embeddings.""" + import cohere + + # Create client + client = cohere.ClientV2() + + # Create test texts + test_texts = [f"This is test document number {i}" for i in range(20)] + + print("\n" + "="*60) + print("MEMORY-EFFICIENT EMBED STREAMING TEST") + print("="*60) + + # Process embeddings using streaming + processed_count = 0 + start_time = time.time() + + for embedding in client.embed_stream( + model="embed-english-v3.0", + input_type="search_document", + texts=test_texts, + batch_size=5, + embedding_types=["float"] + ): + # Process each embedding immediately + if processed_count % 5 == 0: + print(f"Processed {processed_count} embeddings") + + # Verify embedding structure + self.assertIsNotNone(embedding.embedding) + self.assertIsInstance(embedding.embedding, list) + self.assertGreater(len(embedding.embedding), 0) + self.assertEqual(embedding.text, test_texts[embedding.index]) + + processed_count += 1 + + elapsed = time.time() - start_time + + print(f"\n✅ Processed {processed_count} embeddings in {elapsed:.2f}s") + print(f" Average: {elapsed/processed_count:.3f}s per embedding") + print(" Memory usage remained constant throughout!") + + self.assertEqual(processed_count, len(test_texts)) + + @unittest.skipIf(not os.environ.get("CO_API_KEY"), "API key required for integration test") + def test_different_embedding_types(self): + """Test streaming with different embedding types.""" + import cohere + + client = cohere.ClientV2() + + texts = ["Hello world", "Test embedding"] + + # Test with int8 embeddings (more memory efficient) + embeddings = list(client.embed_stream( + model="embed-english-v3.0", + input_type="search_document", + texts=texts, + embedding_types=["int8", "float"] + )) + + # Should get embeddings for each type + self.assertGreater(len(embeddings), 0) + + # Check we got different types + embedding_types = {e.embedding_type for e in embeddings} + self.assertIn("int8", embedding_types) + self.assertIn("float", embedding_types) + + +if __name__ == "__main__": + # Run the old demo if called directly with no API key + if not os.environ.get("CO_API_KEY"): + print("Running demo mode without API key...") + demonstrate_memory_savings() + else: + # Run as unittest if API key is available + unittest.main() \ No newline at end of file
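Usage sketch for the additions above: a minimal, illustrative example of the `embed_stream` method and the `batch_size`/`max_workers` parameters introduced in this diff. The model name and the environment-variable handling are assumptions, not part of the change itself.

    import os
    import cohere

    # Hypothetical setup; any valid API key source works
    client = cohere.Client(api_key=os.environ["CO_API_KEY"])

    # Streaming: yields one StreamedEmbedding at a time, issuing one request per
    # `batch_size` texts instead of materializing every embedding in memory.
    for emb in client.embed_stream(
        texts=["hello", "goodbye", "how are you"],
        model="embed-english-v3.0",
        input_type="classification",
        batch_size=2,
    ):
        print(emb.index, emb.embedding_type, emb.embedding[:3])

    # Non-streaming embed with the new tuning knobs: batch_size overrides the
    # default embed_batch_size, and max_workers sizes a per-call ThreadPoolExecutor.
    response = client.embed(
        texts=[f"doc {i}" for i in range(200)],
        model="embed-english-v3.0",
        input_type="search_document",
        batch_size=50,
        max_workers=4,
    )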