diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 3c368abef..8640f378f 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -169,7 +169,7 @@ jobs: - name: Run tests run: poetry run python -m pytest tests/unit - check-linting: + code-format-check: runs-on: ubuntu-latest strategy: matrix: @@ -216,10 +216,10 @@ jobs: - name: Install library run: poetry install --no-interaction #---------------------------------------------- - # black the code + # Run Ruff format check #---------------------------------------------- - - name: Black - run: poetry run black --check src + - name: Ruff format check + run: poetry run ruff format --check check-types: runs-on: ubuntu-latest diff --git a/examples/custom_cred_provider.py b/examples/custom_cred_provider.py index 67945f23c..6a596d76e 100644 --- a/examples/custom_cred_provider.py +++ b/examples/custom_cred_provider.py @@ -21,7 +21,6 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), credentials_provider=creds, ) as connection: - for x in range(1, 5): cursor = connection.cursor() cursor.execute("SELECT 1+1") diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index 5bc6c6793..57fcf12ba 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -1,6 +1,7 @@ """ Test for SEA asynchronous query execution functionality. """ + import os import sys import logging diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index a200d97d3..6c0d773e7 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -1,6 +1,7 @@ """ Test for SEA metadata functionality. """ + import os import sys import logging diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py index 516c1bbb8..01d8eadf8 100644 --- a/examples/experimental/tests/test_sea_session.py +++ b/examples/experimental/tests/test_sea_session.py @@ -1,6 +1,7 @@ """ Test for SEA session management functionality. """ + import os import sys import logging diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 4e12d5aa4..266efa9cf 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -1,6 +1,7 @@ """ Test for SEA synchronous query execution functionality. """ + import os import sys import logging diff --git a/examples/insert_data.py b/examples/insert_data.py index 053ed158c..bc845b370 100644 --- a/examples/insert_data.py +++ b/examples/insert_data.py @@ -6,7 +6,6 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), ) as connection: - with connection.cursor() as cursor: cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)") diff --git a/examples/interactive_oauth.py b/examples/interactive_oauth.py index 8dbc8c47c..a78012df2 100644 --- a/examples/interactive_oauth.py +++ b/examples/interactive_oauth.py @@ -17,7 +17,6 @@ server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path=os.getenv("DATABRICKS_HTTP_PATH"), ) as connection: - for x in range(1, 100): cursor = connection.cursor() cursor.execute("SELECT 1+1") diff --git a/examples/persistent_oauth.py b/examples/persistent_oauth.py index 1a2eded28..22892f64b 100644 --- a/examples/persistent_oauth.py +++ b/examples/persistent_oauth.py @@ -51,7 +51,6 @@ def read(self, hostname: str) -> Optional[OAuthToken]: auth_type="databricks-oauth", experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json"), ) as connection: - for x in range(1, 100): cursor = connection.cursor() cursor.execute("SELECT 1+1") diff --git a/examples/proxy_authentication.py b/examples/proxy_authentication.py index 8547336b3..517ba692e 100644 --- a/examples/proxy_authentication.py +++ b/examples/proxy_authentication.py @@ -20,24 +20,26 @@ # Configure logging to see proxy activity logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) # Uncomment for detailed debugging (shows HTTP requests/responses) # logging.getLogger("urllib3").setLevel(logging.DEBUG) # logging.getLogger("urllib3.connectionpool").setLevel(logging.DEBUG) + def check_proxy_environment(): """Check if proxy environment variables are configured.""" - proxy_vars = ['HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy'] - configured_proxies = {var: os.environ.get(var) for var in proxy_vars if os.environ.get(var)} - + proxy_vars = ["HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy"] + configured_proxies = { + var: os.environ.get(var) for var in proxy_vars if os.environ.get(var) + } + if configured_proxies: print("✓ Proxy environment variables found:") for var, value in configured_proxies.items(): # Hide credentials in output for security - safe_value = value.split('@')[-1] if '@' in value else value + safe_value = value.split("@")[-1] if "@" in value else value print(f" {var}: {safe_value}") return True else: @@ -45,97 +47,101 @@ def check_proxy_environment(): print(" Set HTTP_PROXY and/or HTTPS_PROXY if using a proxy") return False + def test_connection(connection_params, test_name): """Test a database connection with given parameters.""" print(f"\n--- Testing {test_name} ---") - + try: with sql.connect(**connection_params) as connection: print("✓ Successfully connected!") - + with connection.cursor() as cursor: # Test basic query - cursor.execute("SELECT current_user() as user, current_database() as database") + cursor.execute( + "SELECT current_user() as user, current_database() as database" + ) result = cursor.fetchone() print(f"✓ Connected as user: {result.user}") print(f"✓ Default database: {result.database}") - + # Test a simple computation cursor.execute("SELECT 1 + 1 as result") result = cursor.fetchone() print(f"✓ Query result: 1 + 1 = {result.result}") - + return True - + except Exception as e: print(f"✗ Connection failed: {e}") return False + def main(): print("Databricks SQL Connector - Proxy Authentication Examples") print("=" * 60) - + # Check proxy configuration has_proxy = check_proxy_environment() - + # Get Databricks connection parameters - server_hostname = os.environ.get('DATABRICKS_SERVER_HOSTNAME') - http_path = os.environ.get('DATABRICKS_HTTP_PATH') - access_token = os.environ.get('DATABRICKS_TOKEN') - + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + if not all([server_hostname, http_path, access_token]): print("\n✗ Missing required environment variables:") print(" DATABRICKS_SERVER_HOSTNAME") - print(" DATABRICKS_HTTP_PATH") + print(" DATABRICKS_HTTP_PATH") print(" DATABRICKS_TOKEN") return 1 - + print(f"\nConnecting to: {server_hostname}") - + # Base connection parameters base_params = { - 'server_hostname': server_hostname, - 'http_path': http_path, - 'access_token': access_token + "server_hostname": server_hostname, + "http_path": http_path, + "access_token": access_token, } - + success_count = 0 total_tests = 0 - + # Test 1: Default proxy behavior (no _proxy_auth_method specified) # This uses basic auth if credentials are in proxy URL, otherwise no auth - print("\n" + "="*60) + print("\n" + "=" * 60) print("Test 1: Default Proxy Behavior") print("Uses basic authentication if credentials are in proxy URL") total_tests += 1 if test_connection(base_params, "Default Proxy Behavior"): success_count += 1 - + # Test 2: Explicit basic authentication - print("\n" + "="*60) + print("\n" + "=" * 60) print("Test 2: Explicit Basic Authentication") print("Explicitly requests basic authentication (same as default)") total_tests += 1 basic_params = base_params.copy() - basic_params['_proxy_auth_method'] = 'basic' + basic_params["_proxy_auth_method"] = "basic" if test_connection(basic_params, "Basic Proxy Authentication"): success_count += 1 - + # Test 3: Kerberos/Negotiate authentication - print("\n" + "="*60) + print("\n" + "=" * 60) print("Test 3: Kerberos/Negotiate Authentication") print("Uses Kerberos tickets for proxy authentication") print("Note: Requires valid Kerberos tickets (run 'kinit' first)") total_tests += 1 kerberos_params = base_params.copy() - kerberos_params['_proxy_auth_method'] = 'negotiate' + kerberos_params["_proxy_auth_method"] = "negotiate" if test_connection(kerberos_params, "Kerberos Proxy Authentication"): success_count += 1 - + # Summary - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Summary: {success_count}/{total_tests} tests passed") - + if success_count == total_tests: print("✓ All proxy authentication methods working!") return 0 @@ -149,5 +155,6 @@ def main(): print("Consider checking your proxy configuration") return 1 + if __name__ == "__main__": exit(main()) diff --git a/examples/query_async_execute.py b/examples/query_async_execute.py index de4712fe3..ba649b738 100644 --- a/examples/query_async_execute.py +++ b/examples/query_async_execute.py @@ -7,7 +7,6 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), ) as connection: - with connection.cursor() as cursor: long_running_query = """ SELECT COUNT(*) FROM RANGE(10000 * 16) x diff --git a/examples/query_cancel.py b/examples/query_cancel.py index b67fc0857..e2d8a9e19 100644 --- a/examples/query_cancel.py +++ b/examples/query_cancel.py @@ -10,7 +10,6 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), ) as connection: - with connection.cursor() as cursor: def execute_really_long_query(): diff --git a/examples/query_execute.py b/examples/query_execute.py index 38d2f17a8..57e0bf0ff 100644 --- a/examples/query_execute.py +++ b/examples/query_execute.py @@ -6,7 +6,6 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), ) as connection: - with connection.cursor() as cursor: cursor.execute("SELECT * FROM default.diamonds LIMIT 2") result = cursor.fetchall() diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py index f615d082c..93da19267 100644 --- a/examples/query_tags_example.py +++ b/examples/query_tags_example.py @@ -17,14 +17,13 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), session_configuration={ - 'QUERY_TAGS': 'team:engineering,test:query-tags', - 'ansi_mode': False - } + "QUERY_TAGS": "team:engineering,test:query-tags", + "ansi_mode": False, + }, ) as connection: - with connection.cursor() as cursor: cursor.execute("SELECT 1") result = cursor.fetchone() print(f" Result: {result[0]}") -print("\n=== Query Tags Example Complete ===") \ No newline at end of file +print("\n=== Query Tags Example Complete ===") diff --git a/examples/set_user_agent.py b/examples/set_user_agent.py index 093f03bd5..acfaebe03 100644 --- a/examples/set_user_agent.py +++ b/examples/set_user_agent.py @@ -7,7 +7,6 @@ access_token=os.getenv("DATABRICKS_TOKEN"), user_agent_entry="ExamplePartnerTag", ) as connection: - with connection.cursor() as cursor: cursor.execute("SELECT * FROM default.diamonds LIMIT 2") result = cursor.fetchall() diff --git a/examples/staging_ingestion.py b/examples/staging_ingestion.py index a55be4778..fdb4b7c8e 100644 --- a/examples/staging_ingestion.py +++ b/examples/staging_ingestion.py @@ -41,7 +41,6 @@ _complete_path = os.path.realpath(FILEPATH) if not os.path.exists(_complete_path): - # It's easiest to save a file in the same directory as this script. But any path to a file will work. raise Exception( "You need to set FILEPATH in this script to a file that actually exists." @@ -56,9 +55,7 @@ access_token=os.getenv("DATABRICKS_TOKEN"), staging_allowed_local_path=staging_allowed_local_path, ) as connection: - with connection.cursor() as cursor: - # Ingestion commands are executed like any other SQL. # Here's a sample PUT query. You can remove OVERWRITE at the end to avoid silently overwriting data. query = f"PUT '{_complete_path}' INTO 'stage://tmp/{INGESTION_USER}/pysql_examples/demo.csv' OVERWRITE" diff --git a/examples/streaming_put.py b/examples/streaming_put.py index 4e7697099..ab4c8c9bb 100644 --- a/examples/streaming_put.py +++ b/examples/streaming_put.py @@ -14,21 +14,20 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), ) as connection: - with connection.cursor() as cursor: # Create a simple data stream data = b"Hello, streaming world!" stream = io.BytesIO(data) - + # Get catalog, schema, and volume from environment variables catalog = os.getenv("DATABRICKS_CATALOG") schema = os.getenv("DATABRICKS_SCHEMA") volume = os.getenv("DATABRICKS_VOLUME") - + # Upload to Unity Catalog volume cursor.execute( f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE", - input_stream=stream + input_stream=stream, ) - - print("File uploaded successfully!") \ No newline at end of file + + print("File uploaded successfully!") diff --git a/examples/transactions.py b/examples/transactions.py index 6f58dbd2d..85d05c595 100644 --- a/examples/transactions.py +++ b/examples/transactions.py @@ -6,7 +6,6 @@ http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), ) as connection: - # Disable autocommit to use explicit transactions connection.autocommit = False diff --git a/examples/v3_retries_query_execute.py b/examples/v3_retries_query_execute.py index aaab47d11..c938ba8b8 100644 --- a/examples/v3_retries_query_execute.py +++ b/examples/v3_retries_query_execute.py @@ -36,7 +36,6 @@ _retry_dangerous_codes=[502, 400], _retry_max_redirects=2, ) as connection: - with connection.cursor() as cursor: cursor.execute("SELECT * FROM default.diamonds LIMIT 2") result = cursor.fetchall() diff --git a/poetry.lock b/poetry.lock index 7d0845a58..e7af46dca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "astroid" @@ -15,42 +15,6 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} -[[package]] -name = "black" -version = "22.12.0" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.7" -groups = ["dev"] -files = [ - {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, - {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, - {file = "black-22.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d30b212bffeb1e252b31dd269dfae69dd17e06d92b87ad26e23890f3efea366f"}, - {file = "black-22.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:7412e75863aa5c5411886804678b7d083c7c28421210180d67dfd8cf1221e1f4"}, - {file = "black-22.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c116eed0efb9ff870ded8b62fe9f28dd61ef6e9ddd28d83d7d264a38417dcee2"}, - {file = "black-22.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1f58cbe16dfe8c12b7434e50ff889fa479072096d79f0a7f25e4ab8e94cd8350"}, - {file = "black-22.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77d86c9f3db9b1bf6761244bc0b3572a546f5fe37917a044e02f3166d5aafa7d"}, - {file = "black-22.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d9fe8fee3401e02e79767016b4907820a7dc28d70d137eb397b92ef3cc5bfc"}, - {file = "black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101c69b23df9b44247bd88e1d7e90154336ac4992502d4197bdac35dd7ee3320"}, - {file = "black-22.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:559c7a1ba9a006226f09e4916060982fd27334ae1998e7a38b3f33a37f7a2148"}, - {file = "black-22.12.0-py3-none-any.whl", hash = "sha256:436cc9167dd28040ad90d3b404aec22cedf24a6e4d7de221bec2730ec0c97bcf"}, - {file = "black-22.12.0.tar.gz", hash = "sha256:229351e5a18ca30f447bf724d007f890f97e13af070bb6ad4c0a441cd7596a2f"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} -typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "certifi" version = "2025.1.31" @@ -246,21 +210,6 @@ files = [ {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"}, ] -[[package]] -name = "click" -version = "8.1.8" -description = "Composable command line interface toolkit" -optional = false -python-versions = ">=3.7" -groups = ["dev"] -files = [ - {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, - {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - [[package]] name = "colorama" version = "0.4.6" @@ -268,7 +217,7 @@ description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" groups = ["dev"] -markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" +markers = "sys_platform == \"win32\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -1268,18 +1217,6 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] -[[package]] -name = "pathspec" -version = "0.12.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.8" -groups = ["dev"] -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "platformdirs" version = "4.3.6" @@ -1746,6 +1683,35 @@ cryptography = ">=1.3" pyspnego = {version = "*", extras = ["kerberos"]} requests = ">=1.1.0" +[[package]] +name = "ruff" +version = "0.14.13" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "ruff-0.14.13-py3-none-linux_armv6l.whl", hash = "sha256:76f62c62cd37c276cb03a275b198c7c15bd1d60c989f944db08a8c1c2dbec18b"}, + {file = "ruff-0.14.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:914a8023ece0528d5cc33f5a684f5f38199bbb566a04815c2c211d8f40b5d0ed"}, + {file = "ruff-0.14.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d24899478c35ebfa730597a4a775d430ad0d5631b8647a3ab368c29b7e7bd063"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9aaf3870f14d925bbaf18b8a2347ee0ae7d95a2e490e4d4aea6813ed15ebc80e"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac5b7f63dd3b27cc811850f5ffd8fff845b00ad70e60b043aabf8d6ecc304e09"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d2b1097750d90ba82ce4ba676e85230a0ed694178ca5e61aa9b459970b3eb9"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7d0bf87705acbbcb8d4c24b2d77fbb73d40210a95c3903b443cd9e30824a5032"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3eb5da8e2c9e9f13431032fdcbe7681de9ceda5835efee3269417c13f1fed5c"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:642442b42957093811cd8d2140dfadd19c7417030a7a68cf8d51fcdd5f217427"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4acdf009f32b46f6e8864af19cbf6841eaaed8638e65c8dac845aea0d703c841"}, + {file = "ruff-0.14.13-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:591a7f68860ea4e003917d19b5c4f5ac39ff558f162dc753a2c5de897fd5502c"}, + {file = "ruff-0.14.13-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:774c77e841cc6e046fc3e91623ce0903d1cd07e3a36b1a9fe79b81dab3de506b"}, + {file = "ruff-0.14.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:61f4e40077a1248436772bb6512db5fc4457fe4c49e7a94ea7c5088655dd21ae"}, + {file = "ruff-0.14.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6d02f1428357fae9e98ac7aa94b7e966fd24151088510d32cf6f902d6c09235e"}, + {file = "ruff-0.14.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e399341472ce15237be0c0ae5fbceca4b04cd9bebab1a2b2c979e015455d8f0c"}, + {file = "ruff-0.14.13-py3-none-win32.whl", hash = "sha256:ef720f529aec113968b45dfdb838ac8934e519711da53a0456038a0efecbd680"}, + {file = "ruff-0.14.13-py3-none-win_amd64.whl", hash = "sha256:6070bd026e409734b9257e03e3ef18c6e1a216f0435c6751d7a8ec69cb59abef"}, + {file = "ruff-0.14.13-py3-none-win_arm64.whl", hash = "sha256:7ab819e14f1ad9fe39f246cfcc435880ef7a9390d81a2b6ac7e01039083dd247"}, + {file = "ruff-0.14.13.tar.gz", hash = "sha256:83cd6c0763190784b99650a20fec7633c59f6ebe41c5cc9d45ee42749563ad47"}, +] + [[package]] name = "six" version = "1.17.0" @@ -1969,4 +1935,4 @@ pyarrow = ["pyarrow", "pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "ec311bf26ec866de2f427bcdf4ec69ceed721bfd70edfae3aba1ac12882a09d6" +content-hash = "db62280ec7965e3a156857f667076e7853ce586b66ec4419df60aa45f6b0138b" diff --git a/pyproject.toml b/pyproject.toml index 8a635588c..7f7b020a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ pyarrow = ["pyarrow"] pytest = "^7.1.2" mypy = "^1.10.1" pylint = ">=2.12.0" -black = "^22.3.0" pytest-dotenv = "^0.5.2" pytest-cov = "^4.0.0" pytest-xdist = "^3.0.0" @@ -49,6 +48,7 @@ numpy = [ { version = ">=1.16.6", python = ">=3.8,<3.11" }, { version = ">=1.23.4", python = ">=3.11" }, ] +ruff = "^0.14.13" [tool.poetry.urls] "Homepage" = "https://github.com/databricks/databricks-sql-python" @@ -62,8 +62,22 @@ build-backend = "poetry.core.masonry.api" ignore_missing_imports = "true" exclude = ['ttypes\.py$', 'TCLIService\.py$'] -[tool.black] -exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/' +[tool.ruff.format] +exclude = [ + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".tox", + ".venv", + ".svn", + "_build", + "buck-out", + "build", + "dist", + "**/thrift_api/**" +] [tool.pytest.ini_options] markers = [ diff --git a/scripts/dependency_manager.py b/scripts/dependency_manager.py index 15e119841..38e1fbe42 100644 --- a/scripts/dependency_manager.py +++ b/scripts/dependency_manager.py @@ -1,7 +1,7 @@ """ Dependency version management for testing. Generates requirements files for min and default dependency versions. -For min versions, creates flexible constraints (e.g., >=1.2.5,<1.3.0) to allow +For min versions, creates flexible constraints (e.g., >=1.2.5,<1.3.0) to allow compatible patch updates instead of pinning exact versions. """ @@ -12,28 +12,29 @@ from packaging.requirements import Requirement from pathlib import Path + class DependencyManager: def __init__(self, pyproject_path="pyproject.toml"): self.pyproject_path = Path(pyproject_path) self.dependencies = self._load_dependencies() - + # Map of packages that need specific transitive dependency constraints when downgraded self.transitive_dependencies = { - 'pandas': { + "pandas": { # When pandas is downgraded to 1.x, ensure numpy compatibility - 'numpy': { - 'min_constraint': '>=1.16.5,<2.0.0', # pandas 1.x works with numpy 1.x - 'applies_when': lambda version: version.startswith('1.') + "numpy": { + "min_constraint": ">=1.16.5,<2.0.0", # pandas 1.x works with numpy 1.x + "applies_when": lambda version: version.startswith("1."), } } } - + def _load_dependencies(self): """Load dependencies from pyproject.toml""" - with open(self.pyproject_path, 'r') as f: + with open(self.pyproject_path, "r") as f: pyproject = toml.load(f) - return pyproject['tool']['poetry']['dependencies'] - + return pyproject["tool"]["poetry"]["dependencies"] + def _parse_constraint(self, name, constraint): """Parse a dependency constraint into version info""" if isinstance(constraint, str): @@ -45,66 +46,69 @@ def _parse_constraint(self, name, constraint): # Find the constraint that matches the current Python version for item in constraint: - if 'python' in item: - python_spec = item['python'] + if "python" in item: + python_spec = item["python"] # Parse the Python version specifier spec_set = SpecifierSet(python_spec) # Check if current Python version matches this constraint if current_version in spec_set: - version = item['version'] - is_optional = item.get('optional', False) + version = item["version"] + is_optional = item.get("optional", False) return version, is_optional # Fallback to first constraint if no Python version match first_constraint = constraint[0] - version = first_constraint['version'] - is_optional = first_constraint.get('optional', False) + version = first_constraint["version"] + is_optional = first_constraint.get("optional", False) return version, is_optional elif isinstance(constraint, dict): - if 'version' in constraint: - return constraint['version'], constraint.get('optional', False) + if "version" in constraint: + return constraint["version"], constraint.get("optional", False) return None, False - + def _extract_versions_from_specifier(self, spec_set_str): """Extract minimum version from a specifier set""" try: # Handle caret (^) and tilde (~) constraints that packaging doesn't support - if spec_set_str.startswith('^'): + if spec_set_str.startswith("^"): # ^1.2.3 means >=1.2.3, <2.0.0 min_version = spec_set_str[1:] # Remove ^ return min_version, None - elif spec_set_str.startswith('~'): + elif spec_set_str.startswith("~"): # ~1.2.3 means >=1.2.3, <1.3.0 min_version = spec_set_str[1:] # Remove ~ return min_version, None - + spec_set = SpecifierSet(spec_set_str) min_version = None - + for spec in spec_set: - if spec.operator in ('>=', '=='): + if spec.operator in (">=", "=="): min_version = spec.version break - + return min_version, None except Exception as e: - print(f"Warning: Could not parse constraint '{spec_set_str}': {e}", file=sys.stderr) + print( + f"Warning: Could not parse constraint '{spec_set_str}': {e}", + file=sys.stderr, + ) return None, None - + def _create_flexible_minimum_constraint(self, package_name, min_version): """Create a flexible minimum constraint that allows compatible updates""" try: # Split version into parts - version_parts = min_version.split('.') - + version_parts = min_version.split(".") + if len(version_parts) >= 2: major = version_parts[0] minor = version_parts[1] - + # Special handling for packages that commonly have conflicts # For these packages, use wider constraints to allow more compatibility - if package_name in ['requests', 'urllib3', 'pandas']: + if package_name in ["requests", "urllib3", "pandas"]: # Use wider constraint: >=min_version,=2.18.1,<3.0.0 next_major = int(major) + 1 @@ -119,133 +123,162 @@ def _create_flexible_minimum_constraint(self, package_name, min_version): else: # If version doesn't have minor version, just use exact version return f"{package_name}=={min_version}" - + except (ValueError, IndexError) as e: - print(f"Warning: Could not create flexible constraint for {package_name}=={min_version}: {e}", file=sys.stderr) + print( + f"Warning: Could not create flexible constraint for {package_name}=={min_version}: {e}", + file=sys.stderr, + ) # Fallback to exact version return f"{package_name}=={min_version}" - + def _get_transitive_dependencies(self, package_name, version, version_type): """Get transitive dependencies that need specific constraints based on the main package version""" transitive_reqs = [] - + if package_name in self.transitive_dependencies: transitive_deps = self.transitive_dependencies[package_name] - + for dep_name, dep_config in transitive_deps.items(): # Check if this transitive dependency applies for this version - if dep_config['applies_when'](version): + if dep_config["applies_when"](version): if version_type == "min": # Use the predefined constraint for minimum versions - constraint = dep_config['min_constraint'] + constraint = dep_config["min_constraint"] transitive_reqs.append(f"{dep_name}{constraint}") # For default version_type, we don't add transitive deps as Poetry handles them - + return transitive_reqs - + def generate_requirements(self, version_type="min", include_optional=False): """ Generate requirements for specified version type. - + Args: version_type: "min" or "default" include_optional: Whether to include optional dependencies """ requirements = [] transitive_requirements = [] - + for name, constraint in self.dependencies.items(): - if name == 'python': + if name == "python": continue - + version_constraint, is_optional = self._parse_constraint(name, constraint) if not version_constraint: continue - + if is_optional and not include_optional: continue - + if version_type == "default": # For default, just use the constraint as-is (let poetry resolve) requirements.append(f"{name}{version_constraint}") elif version_type == "min": - min_version, _ = self._extract_versions_from_specifier(version_constraint) + min_version, _ = self._extract_versions_from_specifier( + version_constraint + ) if min_version: # Create flexible constraint that allows patch updates for compatibility - flexible_constraint = self._create_flexible_minimum_constraint(name, min_version) + flexible_constraint = self._create_flexible_minimum_constraint( + name, min_version + ) requirements.append(flexible_constraint) - + # Check if this package needs specific transitive dependencies - transitive_deps = self._get_transitive_dependencies(name, min_version, version_type) + transitive_deps = self._get_transitive_dependencies( + name, min_version, version_type + ) transitive_requirements.extend(transitive_deps) - + # Combine main requirements with transitive requirements all_requirements = requirements + transitive_requirements - + # Remove duplicates (prefer main requirements over transitive ones) seen_packages = set() final_requirements = [] - + # First add main requirements for req in requirements: package_name = Requirement(req).name seen_packages.add(package_name) final_requirements.append(req) - + # Then add transitive requirements that don't conflict for req in transitive_requirements: package_name = Requirement(req).name if package_name not in seen_packages: final_requirements.append(req) - + return final_requirements - - def write_requirements_file(self, filename, version_type="min", include_optional=False): + def write_requirements_file( + self, filename, version_type="min", include_optional=False + ): """Write requirements to a file""" requirements = self.generate_requirements(version_type, include_optional) - - with open(filename, 'w') as f: + + with open(filename, "w") as f: if version_type == "min": - f.write(f"# Minimum compatible dependency versions generated from pyproject.toml\n") - f.write(f"# Uses flexible constraints to resolve compatibility conflicts:\n") - f.write(f"# - Common packages (requests, urllib3, pandas): >=min,=min,=min, str: - ... + def auth_type(self) -> str: ... @abc.abstractmethod - def __call__(self, *args, **kwargs) -> HeaderFactory: - ... + def __call__(self, *args, **kwargs) -> HeaderFactory: ... # Private API: this is an evolving interface and it will change in the future. @@ -231,9 +229,9 @@ def header_factory() -> Dict[str, str]: } if self.azure_workspace_resource_id: - headers[ - self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER - ] = self.azure_workspace_resource_id + headers[self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER] = ( + self.azure_workspace_resource_id + ) return headers diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 2becfb4fb..5264f03f4 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -112,7 +112,6 @@ def startRetryTimer(self): self.retry_policy and self.retry_policy.start_retry_timer() def open(self): - # self.__pool replaces the self.__http used by the original THttpClient _pool_kwargs = {"maxsize": self.max_connections} @@ -159,7 +158,6 @@ def isOpen(self): return self.__resp is not None def flush(self): - # Pull data out of buffer that will be sent in this request data = self.__wbuf.getvalue() self.__wbuf = BytesIO() diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index f75b904fb..b7cfaafa0 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -111,9 +111,9 @@ def add_headers(self, request_headers: Dict[str, str]): """Add authentication headers to the request.""" if self._cached_token and not self._cached_token.is_expired(): - request_headers[ - "Authorization" - ] = f"{self._cached_token.token_type} {self._cached_token.access_token}" + request_headers["Authorization"] = ( + f"{self._cached_token.token_type} {self._cached_token.access_token}" + ) return # Get the external headers first to check if we need token federation diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index 69c6dfbe2..5061a1572 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -35,7 +35,7 @@ def _convert_decimal( # Apply scale (quantize to specific number of decimal places) if specified quantizer = None if scale is not None: - quantizer = decimal.Decimal(f'0.{"0" * scale}') + quantizer = decimal.Decimal(f"0.{'0' * scale}") # Apply precision (total number of significant digits) if specified context = None diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index edee02bfa..7a3060f46 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -675,7 +675,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti num_rows, ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description) elif t_row_set.arrowBatches is not None: - (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table( + ( + arrow_table, + num_rows, + ) = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a0215aae5..120b33980 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -433,8 +433,9 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): if self.open: logger.debug( - "Closing unclosed connection for session " - "{}".format(self.get_session_id_hex()) + "Closing unclosed connection for session {}".format( + self.get_session_id_hex() + ) ) try: self._close(close_cursors=False) @@ -1144,7 +1145,6 @@ def _handle_staging_put( self._handle_staging_http_response(r) def _handle_staging_http_response(self, r): - # fmt: off # HTTP status codes OK = 200 diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index f4770f3c4..7501fcf4d 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -3,6 +3,7 @@ logger = logging.getLogger(__name__) + ### PEP-249 Mandated ### # https://peps.python.org/pep-0249/#exceptions class Error(Exception): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 523fcc1dc..953f9710b 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -599,13 +599,13 @@ def initialize_telemetry_client( batch_size=batch_size, client_context=client_context, ) - TelemetryClientFactory._clients[ - host_url - ] = _TelemetryClientHolder(client) + TelemetryClientFactory._clients[host_url] = ( + _TelemetryClientHolder(client) + ) else: - TelemetryClientFactory._clients[ - host_url - ] = _TelemetryClientHolder(NoopTelemetryClient()) + TelemetryClientFactory._clients[host_url] = ( + _TelemetryClientHolder(NoopTelemetryClient()) + ) except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py index e188ef577..8f843c841 100644 --- a/src/databricks/sql/types.py +++ b/src/databricks/sql/types.py @@ -121,7 +121,7 @@ class Row(tuple): def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": if args and kwargs: - raise ValueError("Can not use both args " "and kwargs to create Row") + raise ValueError("Can not use both args and kwargs to create Row") if kwargs: # create row objects row = tuple.__new__(cls, list(kwargs.values())) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 043183ac2..84d118089 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -761,7 +761,6 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": def convert_to_assigned_datatypes_in_column_table(column_table, description): - converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": @@ -877,7 +876,7 @@ def _create_python_tuple(t_col_value_wrapper): def concat_table_chunks( - table_chunks: List[Union["pyarrow.Table", ColumnTable]] + table_chunks: List[Union["pyarrow.Table", ColumnTable]], ) -> Union["pyarrow.Table", ColumnTable]: if len(table_chunks) == 0: return table_chunks diff --git a/tests/e2e/common/core_tests.py b/tests/e2e/common/core_tests.py index 3f0fdc05d..4cb9252d8 100644 --- a/tests/e2e/common/core_tests.py +++ b/tests/e2e/common/core_tests.py @@ -4,13 +4,11 @@ TypeFailure = namedtuple( "TypeFailure", - "query,columnType,resultType,resultValue," - "actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue,actualValue,actualType,description,conf", ) ResultFailure = namedtuple( "ResultFailure", - "query,columnType,resultType,resultValue," - "actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue,actualValue,actualType,description,conf", ) ExecFailure = namedtuple( "ExecFailure", @@ -81,8 +79,9 @@ def run_tests_on_queries(self, default_conf): if failures: self.fail( - "Failed testing result set with Arrow. " - "Failed queries: {}".format("\n\n".join([str(f) for f in failures])) + "Failed testing result set with Arrow. Failed queries: {}".format( + "\n\n".join([str(f) for f in failures]) + ) ) def run_query(self, cursor, query, columnType, rowValueType, answer, conf): diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index dd7c56996..f57d70ba5 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -125,9 +125,7 @@ def test_long_running_query(self, extra_params): FROM RANGE({scale}) x JOIN RANGE({scale0}) y ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%" - """.format( - scale=scale_factor * scale0, scale0=scale0 - ) + """.format(scale=scale_factor * scale0, scale0=scale0) ) (n,) = cursor.fetchone() diff --git a/tests/e2e/common/predicates.py b/tests/e2e/common/predicates.py index 61de69fd3..b2dabff1a 100644 --- a/tests/e2e/common/predicates.py +++ b/tests/e2e/common/predicates.py @@ -31,7 +31,6 @@ def test_some_pyhive_v1_stuff(): def is_endpoint_test(cli_args=None): - # Currently only supporting tests against DBSQL Endpoints # So we don't read `is_endpoint_test` from the CLI args return True diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b2350bd98..da99da178 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -94,16 +94,18 @@ def _test_retry_disabled_with_message( class SimpleHttpResponse: """A simple HTTP response mock that works with both urllib3 v1.x and v2.x""" - - def __init__(self, status: int, headers: dict, redirect_location: Optional[str] = None): + + def __init__( + self, status: int, headers: dict, redirect_location: Optional[str] = None + ): # Import the correct HTTP message type that urllib3 v1.x expects try: from http.client import HTTPMessage except ImportError: from httplib import HTTPMessage - + self.status = status - # Create proper HTTPMessage for urllib3 v1.x compatibility + # Create proper HTTPMessage for urllib3 v1.x compatibility self.headers = HTTPMessage() for key, value in headers.items(): self.headers[key] = str(value) @@ -116,41 +118,41 @@ def __init__(self, status: int, headers: dict, redirect_location: Optional[str] self._body = b"" self._fp = io.BytesIO(self._body) self._url = "https://example.com" - + def get_redirect_location(self, *args, **kwargs): """Return the redirect location or False""" return False if self._redirect_location is None else self._redirect_location - + def read(self, amt=None): """Mock read method for file-like behavior""" return self._body - + def close(self): """Mock close method""" pass - + def drain_conn(self): """Mock drain_conn method for urllib3 v2.x""" pass - + def isclosed(self): """Mock isclosed method for urllib3 v1.x""" return False - + def release_conn(self): """Mock release_conn method for thrift HTTP client""" pass - + @property def data(self): """Mock data property for urllib3 v2.x""" return self._body - + @property def url(self): """Mock url property""" return self._url - + @url.setter def url(self, value): """Mock url setter""" @@ -162,7 +164,7 @@ def mocked_server_response( status: int = 200, headers: dict = {}, redirect_location: Optional[str] = None ): """Context manager for patching urllib3 responses with version compatibility""" - + mock_response = SimpleHttpResponse(status, headers, redirect_location) with patch("urllib3.connectionpool.HTTPSConnectionPool._get_conn") as getconn_mock: @@ -188,7 +190,7 @@ def mock_sequential_server_responses(responses: List[dict]): SimpleHttpResponse( status=resp["status"], headers=resp["headers"], - redirect_location=resp["redirect_location"] + redirect_location=resp["redirect_location"], ) for resp in responses ] @@ -539,12 +541,12 @@ def test_3xx_redirect_codes_are_not_retried( max_redirects, expected_call_count = 1, 1 # Code 302 is a redirect, but 3xx codes are not retried per policy - # Note: We don't set redirect_location because that would cause urllib3 v2.x + # Note: We don't set redirect_location because that would cause urllib3 v2.x # to follow redirects internally, bypassing our retry policy test - with mocked_server_response( - status=302, redirect_location=None - ) as mock_obj: - with pytest.raises(RequestError): # Should get RequestError, not MaxRetryError + with mocked_server_response(status=302, redirect_location=None) as mock_obj: + with pytest.raises( + RequestError + ): # Should get RequestError, not MaxRetryError with self.connection( extra_params={ **extra_params, @@ -575,12 +577,12 @@ def test_3xx_codes_not_retried_regardless_of_max_redirects_setting( according to the DatabricksRetryPolicy regardless of redirect settings. """ # Code 302 is a redirect, but 3xx codes are not retried per policy - # Note: We don't set redirect_location because that would cause urllib3 v2.x + # Note: We don't set redirect_location because that would cause urllib3 v2.x # to follow redirects internally, bypassing our retry policy test - with mocked_server_response( - status=302, redirect_location=None - ) as mock_obj: - with pytest.raises(RequestError): # Should get RequestError, not MaxRetryError + with mocked_server_response(status=302, redirect_location=None) as mock_obj: + with pytest.raises( + RequestError + ): # Should get RequestError, not MaxRetryError with self.connection( extra_params={ **extra_params, @@ -599,9 +601,7 @@ def test_3xx_codes_not_retried_regardless_of_max_redirects_setting( {"use_sea": True}, ], ) - def test_3xx_codes_stop_request_immediately_no_retry_attempts( - self, extra_params - ): + def test_3xx_codes_stop_request_immediately_no_retry_attempts(self, extra_params): # Since 3xx codes are not retried per policy, we only ever see the first 302 response responses = [ {"status": 302, "headers": {}, "redirect_location": "/foo.bar"}, diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 73aa0a113..3790aa2b3 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -44,7 +44,6 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): with self.connection( extra_params={"staging_allowed_local_path": temp_path} ) as conn: - cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' OVERWRITE" cursor.execute(query) @@ -80,9 +79,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises( - Error, match="too many 404 error responses" - ): + with pytest.raises(Error, match="too many 404 error responses"): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -115,7 +112,6 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path( def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path( self, ingestion_user ): - fh, temp_path = tempfile.mkstemp() original_text = "hello world!".encode("utf-8") diff --git a/tests/e2e/common/streaming_put_tests.py b/tests/e2e/common/streaming_put_tests.py index 83da10fd3..7fb86fe6d 100644 --- a/tests/e2e/common/streaming_put_tests.py +++ b/tests/e2e/common/streaming_put_tests.py @@ -15,36 +15,38 @@ class PySQLStreamingPutTestSuiteMixin: def test_streaming_put_basic(self, catalog, schema): """Test basic streaming PUT functionality.""" - + # Create test data test_data = b"Hello, streaming world! This is test data." filename = "streaming_put_test.txt" file_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}" - + try: with self.connection() as conn: with conn.cursor() as cursor: self._cleanup_test_file(file_path) - + with io.BytesIO(test_data) as stream: cursor.execute( f"PUT '__input_stream__' INTO '{file_path}'", - input_stream=stream + input_stream=stream, ) - + # Verify file exists cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'") files = cursor.fetchall() - + # Check if our file is in the list file_paths = [row[0] for row in files] - assert file_path in file_paths, f"File {file_path} not found in {file_paths}" + assert file_path in file_paths, ( + f"File {file_path} not found in {file_paths}" + ) finally: self._cleanup_test_file(file_path) - + def test_streaming_put_missing_stream(self, catalog, schema): """Test that missing stream raises appropriate error.""" - + with self.connection() as conn: with conn.cursor() as cursor: # Test without providing stream @@ -52,14 +54,16 @@ def test_streaming_put_missing_stream(self, catalog, schema): cursor.execute( f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'" # Note: No input_stream parameter - ) + ) def _cleanup_test_file(self, file_path): """Clean up a test file if it exists.""" try: - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/"} + ) as conn: with conn.cursor() as cursor: cursor.execute(f"REMOVE '{file_path}'") logger.info("Successfully cleaned up test file: %s", file_path) except Exception as e: - logger.error("Cleanup failed for %s: %s", file_path, e) \ No newline at end of file + logger.error("Cleanup failed for %s: %s", file_path, e) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 93e63bd28..8e9e717c1 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -43,7 +43,6 @@ def test_uc_volume_life_cycle(self, catalog, schema): with self.connection( extra_params={"staging_allowed_local_path": temp_path} ) as conn: - cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) @@ -80,9 +79,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with pytest.raises( - Error, match="too many 404 error responses" - ): + with pytest.raises(Error, match="too many 404 error responses"): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -115,7 +112,6 @@ def test_uc_volume_put_fails_without_staging_allowed_local_path( def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( self, catalog, schema ): - fh, temp_path = tempfile.mkstemp() original_text = "hello world!".encode("utf-8") diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py index f9df0f377..6360dd6d6 100644 --- a/tests/e2e/test_circuit_breaker.py +++ b/tests/e2e/test_circuit_breaker.py @@ -143,7 +143,9 @@ def mock_request(*args, **kwargs): # Wait for circuit to open (async telemetry may take time) assert wait_for_circuit_state( circuit_breaker, [STATE_OPEN], timeout=5 - ), f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}" + ), ( + f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}" + ) # Circuit should be OPEN after rate-limit failures assert circuit_breaker.current_state == STATE_OPEN @@ -262,7 +264,9 @@ def mock_conditional_request(*args, **kwargs): # Wait for full recovery assert wait_for_circuit_state( circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5 - ), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}" + ), ( + f"Circuit didn't fully recover, state: {circuit_breaker.current_state}" + ) if __name__ == "__main__": diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index d075a5670..ac68db09f 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -11,7 +11,7 @@ class TestComplexTypes(PySQLPytestTestCase): def table_fixture(self, connection_details): self.arguments = connection_details.copy() """A pytest fixture that creates a table with a complex type, inserts a record, yields, and then drops the table""" - + table_name = f"pysql_test_complex_types_table_{str(uuid4()).replace('-', '_')}" self.table_name = table_name @@ -64,9 +64,7 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): """Confirms the return types of a complex type field when reading as arrow""" with self.cursor() as cursor: - result = cursor.execute( - f"SELECT * FROM {table_fixture} LIMIT 1" - ).fetchone() + result = cursor.execute(f"SELECT * FROM {table_fixture} LIMIT 1").fetchone() assert isinstance(result[field], expected_type) @@ -86,8 +84,6 @@ def test_read_complex_types_as_string(self, field, table_fixture): with self.cursor( extra_params={"_use_arrow_native_complex_types": False} ) as cursor: - result = cursor.execute( - f"SELECT * FROM {table_fixture} LIMIT 1" - ).fetchone() + result = cursor.execute(f"SELECT * FROM {table_fixture} LIMIT 1").fetchone() assert isinstance(result[field], str) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index bed348c2c..563ca6006 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -74,10 +74,11 @@ def callback_wrapper(self_client, future, sent_count): captured_futures.append(future) original_callback(self_client, future, sent_count) - with patch.object( - TelemetryClient, "_send_telemetry", send_telemetry_wrapper - ), patch.object( - TelemetryClient, "_telemetry_request_callback", callback_wrapper + with ( + patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), + patch.object( + TelemetryClient, "_telemetry_request_callback", callback_wrapper + ), ): def execute_query_worker(thread_id): @@ -124,9 +125,13 @@ def execute_query_worker(thread_id): response = future.result() # Check status using urllib3 method (response.status instead of response.raise_for_status()) if response.status >= 400: - raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}") + raise Exception( + f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}" + ) # Parse JSON using urllib3 method (response.data.decode() instead of response.json()) - response_data = json.loads(response.data.decode()) if response.data else {} + response_data = ( + json.loads(response.data.decode()) if response.data else {} + ) captured_responses.append(response_data) except Exception as e: captured_exceptions.append(e) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index e04e348c9..775979e84 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -65,9 +65,9 @@ for name in test_loader.getTestCaseNames(DecimalTestsMixin): if name.startswith("test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( - fn - ) + decorated = skipUnless( + pysql_supports_arrow(), "Decimal tests need arrow support" + )(fn) setattr(DecimalTestsMixin, name, decorated) @@ -195,7 +195,6 @@ class TestPySQLAsyncQueriesSuite(PySQLPytestTestCase): ], ) def test_execute_async__long_running(self, extra_params): - long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'" with self.cursor(extra_params) as cursor: cursor.execute_async(long_running_query) @@ -901,7 +900,13 @@ def test_timestamps_arrow(self): ) def test_multi_timestamps_arrow(self, extra_params): with self.cursor( - {"session_configuration": {"ansi_mode": False, "query_tags": "test:multi-timestamps,driver:python"}, **extra_params} + { + "session_configuration": { + "ansi_mode": False, + "query_tags": "test:multi-timestamps,driver:python", + }, + **extra_params, + } ) as cursor: query, expected = self.multi_query() expected = [ @@ -1015,9 +1020,9 @@ def test_row_limit_with_smaller_result(self): rows = cursor.fetchall() # Check if all rows are returned (not limited by row_limit) - assert ( - len(rows) == expected_rows - ), f"Expected {expected_rows} rows, got {len(rows)}" + assert len(rows) == expected_rows, ( + f"Expected {expected_rows} rows, got {len(rows)}" + ) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_row_limit_with_arrow_larger_result(self): @@ -1029,9 +1034,9 @@ def test_row_limit_with_arrow_larger_result(self): arrow_table = cursor.fetchall_arrow() # Check if the number of rows in the arrow table is limited to row_limit - assert ( - arrow_table.num_rows == row_limit - ), f"Expected {row_limit} rows, got {arrow_table.num_rows}" + assert arrow_table.num_rows == row_limit, ( + f"Expected {row_limit} rows, got {arrow_table.num_rows}" + ) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_row_limit_with_arrow_smaller_result(self): @@ -1044,9 +1049,9 @@ def test_row_limit_with_arrow_smaller_result(self): arrow_table = cursor.fetchall_arrow() # Check if all rows are returned (not limited by row_limit) - assert ( - arrow_table.num_rows == expected_rows - ), f"Expected {expected_rows} rows, got {arrow_table.num_rows}" + assert arrow_table.num_rows == expected_rows, ( + f"Expected {expected_rows} rows, got {arrow_table.num_rows}" + ) # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 7370eea93..13c2a6608 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -157,9 +157,11 @@ def inline_table(self, connection_details): Note that this fixture doesn't clean itself up. So the table will remain in the schema for use by subsequent test runs. """ - + # Generate unique table name to avoid conflicts in parallel execution - table_name = f"pysql_e2e_inline_param_test_table_{str(uuid4()).replace('-', '_')}" + table_name = ( + f"pysql_e2e_inline_param_test_table_{str(uuid4()).replace('-', '_')}" + ) self.inline_table_name = table_name self._create_inline_table(table_name) @@ -187,11 +189,13 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column) :paramstyle: This is a no-op but is included to make the test-code easier to read. """ - if not hasattr(self, 'inline_table_name'): - table_name = f"pysql_e2e_inline_param_test_table_{str(uuid4()).replace('-', '_')}" + if not hasattr(self, "inline_table_name"): + table_name = ( + f"pysql_e2e_inline_param_test_table_{str(uuid4()).replace('-', '_')}" + ) self.inline_table_name = table_name self._create_inline_table(table_name) - + table_name = self.inline_table_name INSERT_QUERY = f"INSERT INTO {table_name} (`{target_column}`) VALUES (%(p)s)" SELECT_QUERY = f"SELECT {target_column} `col` FROM {table_name} LIMIT 1" @@ -412,13 +416,13 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) ): cursor.execute("SELECT %(p)s", parameters={"p": 1}) if use_inline_params is True: - assert ( - "Consider using native parameters." in caplog.text - ), "Log message should be suppressed" + assert "Consider using native parameters." in caplog.text, ( + "Log message should be suppressed" + ) elif use_inline_params == "silent": - assert ( - "Consider using native parameters." not in caplog.text - ), "Log message should not be supressed" + assert "Consider using native parameters." not in caplog.text, ( + "Log message should not be supressed" + ) def test_positional_native_params_with_defaults(self): query = "SELECT ? col" diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py index 0a57edd3c..50f502e34 100644 --- a/tests/e2e/test_telemetry_e2e.py +++ b/tests/e2e/test_telemetry_e2e.py @@ -1,6 +1,7 @@ """ E2E test for telemetry - verifies telemetry behavior with different scenarios """ + import time import threading import logging @@ -61,6 +62,7 @@ def telemetry_setup_teardown(self): # Clear feature flags cache to prevent state leakage between tests from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + with FeatureFlagsContextFactory._lock: FeatureFlagsContextFactory._context_map.clear() if FeatureFlagsContextFactory._executor: @@ -97,12 +99,21 @@ def assert_system_config(self, event): assert sys_config is not None # Check all required fields are non-empty - for field in ['driver_name', 'driver_version', 'os_name', 'os_version', - 'os_arch', 'runtime_name', 'runtime_version', 'runtime_vendor', - 'locale_name', 'char_set_encoding']: + for field in [ + "driver_name", + "driver_version", + "os_name", + "os_version", + "os_arch", + "runtime_name", + "runtime_version", + "runtime_vendor", + "locale_name", + "char_set_encoding", + ]: value = getattr(sys_config, field) assert value and len(value) > 0, f"{field} should not be None or empty" - + assert sys_config.driver_name == "Databricks SQL Python Connector" def assert_connection_params(self, event, expected_http_path=None): @@ -112,10 +123,10 @@ def assert_connection_params(self, event, expected_http_path=None): assert conn_params.http_path assert conn_params.host_info is not None assert conn_params.auth_mech is not None - + if expected_http_path: assert conn_params.http_path == expected_http_path - + if conn_params.socket_timeout is not None: assert conn_params.socket_timeout > 0 @@ -126,7 +137,7 @@ def assert_statement_execution(self, event): assert sql_op.statement_type is not None assert sql_op.execution_result is not None assert hasattr(sql_op, "retry_count") - + if sql_op.retry_count is not None: assert sql_op.retry_count >= 0 @@ -139,28 +150,34 @@ def assert_error_info(self, event, expected_error_name=None): assert error_info is not None assert error_info.error_name and len(error_info.error_name) > 0 assert error_info.stack_trace and len(error_info.stack_trace) > 0 - + if expected_error_name: assert error_info.error_name == expected_error_name def verify_events(self, captured_events, captured_futures, expected_count): """Common verification for event count and HTTP responses""" if expected_count == 0: - assert len(captured_events) == 0, f"Expected 0 events, got {len(captured_events)}" - assert len(captured_futures) == 0, f"Expected 0 responses, got {len(captured_futures)}" + assert len(captured_events) == 0, ( + f"Expected 0 events, got {len(captured_events)}" + ) + assert len(captured_futures) == 0, ( + f"Expected 0 responses, got {len(captured_futures)}" + ) else: - assert len(captured_events) == expected_count, \ + assert len(captured_events) == expected_count, ( f"Expected {expected_count} events, got {len(captured_events)}" + ) time.sleep(2) done, _ = wait(captured_futures, timeout=10) - assert len(done) == expected_count, \ + assert len(done) == expected_count, ( f"Expected {expected_count} responses, got {len(done)}" - + ) + for future in done: response = future.result() assert 200 <= response.status < 300 - + # Assert common fields for all events for event in captured_events: self.assert_system_config(event) @@ -168,21 +185,34 @@ def verify_events(self, captured_events, captured_futures, expected_count): # ==================== PARAMETERIZED TESTS ==================== - @pytest.mark.parametrize("enable_telemetry,force_enable,expected_count,test_id", [ - (True, False, 2, "enable_on_force_off"), - (False, True, 2, "enable_off_force_on"), - (False, False, 0, "both_off"), - (None, None, 2, "default_behavior"), - ]) - def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, - force_enable, expected_count, test_id): + @pytest.mark.parametrize( + "enable_telemetry,force_enable,expected_count,test_id", + [ + (True, False, 2, "enable_on_force_off"), + (False, True, 2, "enable_off_force_on"), + (False, False, 0, "both_off"), + (None, None, 2, "default_behavior"), + ], + ) + def test_telemetry_flags( + self, + telemetry_interceptors, + enable_telemetry, + force_enable, + expected_count, + test_id, + ): """Test telemetry behavior with different flag combinations""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ + captured_events, captured_futures, export_wrapper, callback_wrapper = ( telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - + ) + + with ( + patch.object(TelemetryClient, "_export_event", export_wrapper), + patch.object( + TelemetryClient, "_telemetry_request_callback", callback_wrapper + ), + ): extra_params = {"telemetry_batch_size": 1} if enable_telemetry is not None: extra_params["enable_telemetry"] = enable_telemetry @@ -197,27 +227,36 @@ def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, # Give time for async telemetry submission after connection closes time.sleep(0.5) self.verify_events(captured_events, captured_futures, expected_count) - + # Assert statement execution on latency event (if events exist) if expected_count > 0: self.assert_statement_execution(captured_events[-1]) - @pytest.mark.parametrize("query,expected_error", [ - ("SELECT * FROM WHERE INVALID SYNTAX 12345", "ServerOperationError"), - ("SELECT * FROM non_existent_table_xyz_12345", None), - ]) + @pytest.mark.parametrize( + "query,expected_error", + [ + ("SELECT * FROM WHERE INVALID SYNTAX 12345", "ServerOperationError"), + ("SELECT * FROM non_existent_table_xyz_12345", None), + ], + ) def test_sql_errors(self, telemetry_interceptors, query, expected_error): """Test telemetry captures error information for different SQL errors""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ + captured_events, captured_futures, export_wrapper, callback_wrapper = ( telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - }) as conn: + ) + + with ( + patch.object(TelemetryClient, "_export_event", export_wrapper), + patch.object( + TelemetryClient, "_telemetry_request_callback", callback_wrapper + ), + ): + with self.connection( + extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + } + ) as conn: with conn.cursor() as cursor: with pytest.raises(Exception): cursor.execute(query) @@ -229,8 +268,9 @@ def test_sql_errors(self, telemetry_interceptors, query, expected_error): assert len(captured_events) >= 1 # Find event with error_info - error_event = next((e for e in captured_events - if e.entry.sql_driver_log.error_info), None) + error_event = next( + (e for e in captured_events if e.entry.sql_driver_log.error_info), None + ) assert error_event is not None self.assert_system_config(error_event) @@ -239,16 +279,22 @@ def test_sql_errors(self, telemetry_interceptors, query, expected_error): def test_metadata_operation(self, telemetry_interceptors): """Test telemetry for metadata operations (getCatalogs)""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ + captured_events, captured_futures, export_wrapper, callback_wrapper = ( telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - }) as conn: + ) + + with ( + patch.object(TelemetryClient, "_export_event", export_wrapper), + patch.object( + TelemetryClient, "_telemetry_request_callback", callback_wrapper + ), + ): + with self.connection( + extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + } + ) as conn: with conn.cursor() as cursor: catalogs = cursor.catalogs() catalogs.fetchall() @@ -263,17 +309,23 @@ def test_metadata_operation(self, telemetry_interceptors): def test_direct_results(self, telemetry_interceptors): """Test telemetry with direct results (use_cloud_fetch=False)""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ + captured_events, captured_futures, export_wrapper, callback_wrapper = ( telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - "use_cloud_fetch": False, - }) as conn: + ) + + with ( + patch.object(TelemetryClient, "_export_event", export_wrapper), + patch.object( + TelemetryClient, "_telemetry_request_callback", callback_wrapper + ), + ): + with self.connection( + extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": False, + } + ) as conn: with conn.cursor() as cursor: cursor.execute("SELECT 100") result = cursor.fetchall() @@ -286,24 +338,32 @@ def test_direct_results(self, telemetry_interceptors): for event in captured_events: self.assert_system_config(event) self.assert_connection_params(event, self.arguments["http_path"]) - + self.assert_statement_execution(captured_events[-1]) - @pytest.mark.parametrize("close_type", [ - "context_manager", - "explicit_cursor", - "explicit_connection", - "implicit_fetchall", - ]) - def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors, - close_type): + @pytest.mark.parametrize( + "close_type", + [ + "context_manager", + "explicit_cursor", + "explicit_connection", + "implicit_fetchall", + ], + ) + def test_cloudfetch_with_different_close_patterns( + self, telemetry_interceptors, close_type + ): """Test telemetry with cloud fetch using different resource closing patterns""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ + captured_events, captured_futures, export_wrapper, callback_wrapper = ( telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - + ) + + with ( + patch.object(TelemetryClient, "_export_event", export_wrapper), + patch.object( + TelemetryClient, "_telemetry_request_callback", callback_wrapper + ), + ): if close_type == "explicit_connection": # Test explicit connection close conn = sql.connect( @@ -319,24 +379,26 @@ def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors, conn.close() else: # Other patterns use connection context manager - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - "use_cloud_fetch": True, - }) as conn: + with self.connection( + extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": True, + } + ) as conn: if close_type == "context_manager": with conn.cursor() as cursor: cursor.execute("SELECT * FROM range(1000)") result = cursor.fetchall() assert len(result) == 1000 - + elif close_type == "explicit_cursor": cursor = conn.cursor() cursor.execute("SELECT * FROM range(1000)") result = cursor.fetchall() assert len(result) == 1000 cursor.close() - + elif close_type == "implicit_fetchall": cursor = conn.cursor() cursor.execute("SELECT * FROM range(1000)") @@ -350,5 +412,5 @@ def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors, for event in captured_events: self.assert_system_config(event) self.assert_connection_params(event, self.arguments["http_path"]) - + self.assert_statement_execution(captured_events[-1]) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py index d4f6a790a..94f9d81f1 100644 --- a/tests/e2e/test_transactions.py +++ b/tests/e2e/test_transactions.py @@ -136,16 +136,16 @@ def _cleanup(self): def test_default_autocommit_is_true(self): """Test that new connection defaults to autocommit=true.""" - assert ( - self.connection.autocommit is True - ), "New connection should have autocommit=true by default" + assert self.connection.autocommit is True, ( + "New connection should have autocommit=true by default" + ) def test_set_autocommit_to_false(self): """Test successfully setting autocommit to false.""" self.connection.autocommit = False - assert ( - self.connection.autocommit is False - ), "autocommit should be false after setting to false" + assert self.connection.autocommit is False, ( + "autocommit should be false after setting to false" + ) def test_set_autocommit_to_true(self): """Test successfully setting autocommit back to true.""" @@ -155,9 +155,9 @@ def test_set_autocommit_to_true(self): # Then enable self.connection.autocommit = True - assert ( - self.connection.autocommit is True - ), "autocommit should be true after setting to true" + assert self.connection.autocommit is True, ( + "autocommit should be true after setting to true" + ) # ==================== COMMIT TESTS ==================== @@ -536,9 +536,9 @@ def test_set_autocommit_during_active_transaction(self): # Verify error message mentions autocommit or active transaction error_msg = str(exc_info.value).lower() - assert ( - "autocommit" in error_msg or "active transaction" in error_msg - ), "Error should mention autocommit or active transaction" + assert "autocommit" in error_msg or "active transaction" in error_msg, ( + "Error should mention autocommit or active transaction" + ) # Cleanup - rollback the transaction self.connection.rollback() @@ -576,9 +576,9 @@ def test_rollback_without_active_transaction_is_safe(self): def test_get_transaction_isolation_returns_repeatable_read(self): """Test that get_transaction_isolation() returns REPEATABLE_READ.""" isolation_level = self.connection.get_transaction_isolation() - assert ( - isolation_level == "REPEATABLE_READ" - ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" + assert isolation_level == "REPEATABLE_READ", ( + "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" + ) def test_set_transaction_isolation_accepts_repeatable_read(self): """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" diff --git a/tests/e2e/test_variant_types.py b/tests/e2e/test_variant_types.py index 14be3aa3d..5ba71fbfe 100644 --- a/tests/e2e/test_variant_types.py +++ b/tests/e2e/test_variant_types.py @@ -54,15 +54,15 @@ def test_variant_type_detection(self, variant_table): cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0") # Verify column types in description - assert ( - cursor.description[0][1] == "int" - ), "Integer column type not correctly identified" - assert ( - cursor.description[1][1] == "variant" - ), "VARIANT column type not correctly identified" - assert ( - cursor.description[2][1] == "string" - ), "String column type not correctly identified" + assert cursor.description[0][1] == "int", ( + "Integer column type not correctly identified" + ) + assert cursor.description[1][1] == "variant", ( + "VARIANT column type not correctly identified" + ) + assert cursor.description[2][1] == "string", ( + "String column type not correctly identified" + ) def test_variant_data_retrieval(self, variant_table): """Test that VARIANT data is properly retrieved and can be accessed as JSON""" @@ -72,9 +72,9 @@ def test_variant_data_retrieval(self, variant_table): # First row should have a JSON object json_obj = rows[0][1] - assert isinstance( - json_obj, str - ), "VARIANT column should be returned as string" + assert isinstance(json_obj, str), ( + "VARIANT column should be returned as string" + ) parsed = json.loads(json_obj) assert parsed.get("name") == "John" @@ -82,9 +82,9 @@ def test_variant_data_retrieval(self, variant_table): # Second row should have a JSON array json_array = rows[1][1] - assert isinstance( - json_array, str - ), "VARIANT array should be returned as string" + assert isinstance(json_array, str), ( + "VARIANT array should be returned as string" + ) # Parsing to verify it's valid JSON array parsed_array = json.loads(json_array) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index d1b941208..6d334542e 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} @@ -163,10 +165,14 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") - self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider") + self.assertEqual( + type(auth_provider.external_provider).__name__, "ExternalAuthProvider" + ) headers = {} auth_provider.add_headers(headers) @@ -181,7 +187,9 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_use_cert_as_auth": use_cert_as_auth, } mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -191,7 +199,9 @@ def test_get_python_sql_connector_basic_auth(self): } mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) + get_python_sql_connector_auth_provider( + "foo.cloud.databricks.com", mock_http_client, **kwargs + ) self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -200,12 +210,18 @@ def test_get_python_sql_connector_basic_auth(self): def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client + ) self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") - self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider") + self.assertEqual( + type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider" + ) - self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID) + self.assertEqual( + auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID + ) class TestClientCredentialsTokenSource: @@ -264,16 +280,16 @@ def test_no_token_refresh__when_token_is_not_expired( def test_get_token_success(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with the expected format mock_response = MagicMock() mock_response.status = 200 mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}' - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + token = token_source.get_token() # Assert @@ -284,16 +300,16 @@ def test_get_token_success(self, token_source, http_response): def test_get_token_failure(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with error mock_response = MagicMock() mock_response.status = 400 mock_response.data.decode.return_value = "Bad Request" - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) @@ -311,7 +327,6 @@ def credential_provider(self): ) def test_provider_credentials(self, credential_provider): - test_token = Token("access_token", "Bearer", "refresh_token") with patch.object( diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py index 1e02556d9..c859ec1c3 100644 --- a/tests/unit/test_circuit_breaker_manager.py +++ b/tests/unit/test_circuit_breaker_manager.py @@ -133,12 +133,15 @@ def successful_func(): assert breaker.current_state in ["closed", "half-open", "open"] - @pytest.mark.parametrize("old_state,new_state", [ - ("closed", "open"), - ("open", "half-open"), - ("half-open", "closed"), - ("closed", "half-open"), - ]) + @pytest.mark.parametrize( + "old_state,new_state", + [ + ("closed", "open"), + ("open", "half-open"), + ("half-open", "closed"), + ("closed", "half-open"), + ], + ) def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): """Test circuit breaker state listener logs all state transitions.""" from databricks.sql.telemetry.circuit_breaker_manager import ( @@ -155,6 +158,8 @@ def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): mock_new_state = Mock() mock_new_state.name = new_state - with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: + with patch( + "databricks.sql.telemetry.circuit_breaker_manager.logger" + ) as mock_logger: listener.state_change(mock_cb, mock_old_state, mock_new_state) mock_logger.debug.assert_called() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5b6991931..a151b9edd 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -509,7 +509,7 @@ def test_column_name_api(self): expected_values = [["val1", 321, 52.32], ["val2", 2321, 252.32]] - for (row, expected) in zip(data, expected_values): + for row, expected in zip(data, expected_values): self.assertEqual(row.first_col, expected[0]) self.assertEqual(row.second_col, expected[1]) self.assertEqual(row.third_col, expected[2]) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 0c3fc7103..83f4f0d24 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -13,29 +13,31 @@ @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): - def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs): + def create_queue( + self, schema_bytes=None, result_links=None, description=None, **kwargs + ): """Helper method to create ThriftCloudFetchQueue with sensible defaults""" # Set up defaults for commonly used parameters defaults = { - 'max_download_threads': 10, - 'ssl_options': SSLOptions(), - 'session_id_hex': Mock(), - 'statement_id': Mock(), - 'chunk_id': 0, - 'start_row_offset': 0, - 'lz4_compressed': True, + "max_download_threads": 10, + "ssl_options": SSLOptions(), + "session_id_hex": Mock(), + "statement_id": Mock(), + "chunk_id": 0, + "start_row_offset": 0, + "lz4_compressed": True, } - + # Override defaults with any provided kwargs defaults.update(kwargs) - + mock_http_client = MagicMock() return utils.ThriftCloudFetchQueue( schema_bytes=schema_bytes or MagicMock(), result_links=result_links or [], description=description or [], http_client=mock_http_client, - **defaults + **defaults, ) def create_result_link( @@ -198,7 +200,12 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + description = [ + ("col0", "uint32"), + ("col1", "uint32"), + ("col2", "uint32"), + ("col3", "uint32"), + ] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None @@ -277,7 +284,12 @@ def test_remaining_rows_multiple_tables_fully_returned( def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + description = [ + ("col0", "uint32"), + ("col1", "uint32"), + ("col2", "uint32"), + ("col3", "uint32"), + ] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 00b1b849a..547f6ce35 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -131,7 +131,7 @@ def test_run_uncompressed_successful(self, mock_time): self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes) # Patch the log metrics method to avoid division by zero - with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): + with patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -160,11 +160,19 @@ def test_run_compressed_successful(self, mock_time): result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" # Setup mock HTTP response using helper method - self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes) + self._setup_mock_http_response( + mock_http_client, status=200, data=compressed_bytes + ) # Mock the decompression method and log metrics to avoid issues - with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \ - patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): + with ( + patch.object( + downloader.ResultSetDownloadHandler, + "_decompress_data", + return_value=file_bytes, + ), + patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"), + ): d = downloader.ResultSetDownloadHandler( settings, result_link, diff --git a/tests/unit/test_param_escaper.py b/tests/unit/test_param_escaper.py index 9b6b9c246..67331b75a 100644 --- a/tests/unit/test_param_escaper.py +++ b/tests/unit/test_param_escaper.py @@ -229,16 +229,15 @@ class TestInlineToNativeTransformer(object): ), ( "query with doubled wildcards", - "select 1 where " ' like "%%"', + 'select 1 where like "%%"', {"param": None}, - "select 1 where " ' like "%%"', + 'select 1 where like "%%"', ), ), ) def test_transformer( self, label: str, query: str, params: Dict[str, Any], expected: str ): - _params = [ dbsql_parameter_from_primitive(value=value, name=name) for name, value in params.items() diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 26a898cb8..354547d62 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -185,7 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter - "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter + "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter "unsupported_param": "value", # Unsupported parameter } catalog = "test_catalog" @@ -197,7 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "session_confs": { "ansi_mode": "FALSE", "statement_timeout": "3600", - "query_tags": "team:marketing,dashboard:abc123", + "query_tags": "team:marketing,dashboard:abc123", }, "catalog": catalog, "schema": schema, @@ -655,9 +655,9 @@ def test_filter_session_configuration(self): # Verify all returned values are strings for key, value in result.items(): - assert isinstance( - value, str - ), f"Value for key '{key}' is not a string: {type(value)}" + assert isinstance(value, str), ( + f"Value for key '{key}' is not a string: {type(value)}" + ) # Verify specific conversions expected_result = { @@ -700,9 +700,9 @@ def test_filter_session_configuration(self): # Verify all values are strings in case insensitive test for key, value in result.items(): - assert isinstance( - value, str - ), f"Value for key '{key}' is not a string: {type(value)}" + assert isinstance(value, str), ( + f"Value for key '{key}' is not a string: {type(value)}" + ) def test_results_message_to_execute_response_is_staging_operation(self, sea_client): """Test that is_staging_operation is correctly set from manifest.is_volume_operation.""" diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 6471cb4fd..b97ebf845 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -221,9 +221,10 @@ def test_build_queue_arrow_stream( mock_http_client = MagicMock() - with patch( - "databricks.sql.backend.sea.queue.ResultFileDownloadManager" - ), patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None): + with ( + patch("databricks.sql.backend.sea.queue.ResultFileDownloadManager"), + patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None), + ): queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, manifest=arrow_manifest, @@ -520,7 +521,7 @@ def test_hybrid_disposition_with_external_links( # Create result data with external links but no attachment result_data = ResultData(external_links=external_links, attachment=None) - # Build queue + # Build queue mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1d70ec4c4..3bee6a4b2 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -155,7 +155,10 @@ def test_socket_timeout_passthrough(self, mock_client_class): @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): - mock_session_config = {"ANSI_MODE": "FALSE", "QUERY_TAGS": "team:engineering,project:data-pipeline"} + mock_session_config = { + "ANSI_MODE": "FALSE", + "QUERY_TAGS": "team:engineering,project:data-pipeline", + } databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS ) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 86f06aa8a..d59cfd15e 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -14,7 +14,11 @@ FeatureFlagsContextFactory, FeatureFlagsContext, ) -from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType +from databricks.sql.telemetry.models.enums import ( + AuthMech, + AuthFlow, + DatabricksClientType, +) from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverConnectionParameters, @@ -490,7 +494,7 @@ class TestTelemetryEventModels: def test_host_details_serialization(self): """Test HostDetails model serialization.""" host = HostDetails(host_url="test-host.com", port=443) - + # Test JSON string generation json_str = host.to_json() assert isinstance(json_str, str) @@ -503,7 +507,7 @@ def test_driver_connection_parameters_all_fields(self): host_info = HostDetails(host_url="workspace.databricks.com", port=443) proxy_info = HostDetails(host_url="proxy.company.com", port=8080) cf_proxy_info = HostDetails(host_url="cf-proxy.company.com", port=8080) - + params = DriverConnectionParameters( http_path="/sql/1.0/warehouses/abc123", mode=DatabricksClientType.SEA, @@ -532,11 +536,11 @@ def test_driver_connection_parameters_all_fields(self): allowed_volume_ingestion_paths="/Volumes/catalog/schema/volume", query_tags="team:engineering,project:telemetry", ) - + # Serialize to JSON and parse back json_str = params.to_json() json_dict = json.loads(json_str) - + # Verify all new fields are in JSON assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" assert json_dict["mode"] == "SEA" @@ -544,7 +548,10 @@ def test_driver_connection_parameters_all_fields(self): assert json_dict["auth_mech"] == "OAUTH" assert json_dict["auth_flow"] == "BROWSER_BASED_AUTHENTICATION" assert json_dict["socket_timeout"] == 30000 - assert json_dict["azure_workspace_resource_id"] == "/subscriptions/test/resourceGroups/test" + assert ( + json_dict["azure_workspace_resource_id"] + == "/subscriptions/test/resourceGroups/test" + ) assert json_dict["azure_tenant_id"] == "tenant-123" assert json_dict["use_proxy"] is True assert json_dict["use_system_proxy"] is True @@ -562,28 +569,31 @@ def test_driver_connection_parameters_all_fields(self): assert json_dict["async_poll_interval_millis"] == 2000 assert json_dict["support_many_parameters"] is True assert json_dict["enable_complex_datatype_support"] is True - assert json_dict["allowed_volume_ingestion_paths"] == "/Volumes/catalog/schema/volume" + assert ( + json_dict["allowed_volume_ingestion_paths"] + == "/Volumes/catalog/schema/volume" + ) assert json_dict["query_tags"] == "team:engineering,project:telemetry" def test_driver_connection_parameters_minimal_fields(self): """Test DriverConnectionParameters with only required fields.""" host_info = HostDetails(host_url="workspace.databricks.com", port=443) - + params = DriverConnectionParameters( http_path="/sql/1.0/warehouses/abc123", mode=DatabricksClientType.THRIFT, host_info=host_info, ) - + # Note: to_json() filters out None values, so we need to check asdict for complete structure json_str = params.to_json() json_dict = json.loads(json_str) - + # Required fields should be present assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" assert json_dict["mode"] == "THRIFT" assert json_dict["host_info"]["host_url"] == "workspace.databricks.com" - + # Optional fields with None are filtered out by to_json() # This is expected behavior - None values are excluded from JSON output @@ -602,10 +612,10 @@ def test_driver_system_configuration_serialization(self): locale_name="en_US", client_app_name="MyApp", ) - + json_str = sys_config.to_json() json_dict = json.loads(json_str) - + assert json_dict["driver_name"] == "Databricks SQL Connector for Python" assert json_dict["driver_version"] == "3.0.0" assert json_dict["runtime_name"] == "CPython" @@ -622,7 +632,7 @@ def test_telemetry_event_complete_serialization(self): """Test complete TelemetryEvent serialization with all nested objects.""" host_info = HostDetails(host_url="workspace.databricks.com", port=443) proxy_info = HostDetails(host_url="proxy.company.com", port=8080) - + connection_params = DriverConnectionParameters( http_path="/sql/1.0/warehouses/abc123", mode=DatabricksClientType.SEA, @@ -633,7 +643,7 @@ def test_telemetry_event_complete_serialization(self): enable_arrow=True, rows_fetched_per_block=100000, ) - + sys_config = DriverSystemConfiguration( driver_name="Databricks SQL Connector for Python", driver_version="3.0.0", @@ -645,12 +655,12 @@ def test_telemetry_event_complete_serialization(self): os_arch="arm64", char_set_encoding="utf-8", ) - + error_info = DriverErrorInfo( error_name="ConnectionError", stack_trace="Traceback...", ) - + event = TelemetryEvent( session_id="test-session-123", sql_statement_id="test-stmt-456", @@ -660,42 +670,51 @@ def test_telemetry_event_complete_serialization(self): driver_connection_params=connection_params, error_info=error_info, ) - + # Test JSON serialization json_str = event.to_json() assert isinstance(json_str, str) - + # Parse and verify structure parsed = json.loads(json_str) assert parsed["session_id"] == "test-session-123" assert parsed["sql_statement_id"] == "test-stmt-456" assert parsed["operation_latency_ms"] == 1500 assert parsed["auth_type"] == "OAUTH" - + # Verify nested objects - assert parsed["system_configuration"]["driver_name"] == "Databricks SQL Connector for Python" - assert parsed["driver_connection_params"]["http_path"] == "/sql/1.0/warehouses/abc123" + assert ( + parsed["system_configuration"]["driver_name"] + == "Databricks SQL Connector for Python" + ) + assert ( + parsed["driver_connection_params"]["http_path"] + == "/sql/1.0/warehouses/abc123" + ) assert parsed["driver_connection_params"]["use_proxy"] is True - assert parsed["driver_connection_params"]["proxy_host_info"]["host_url"] == "proxy.company.com" + assert ( + parsed["driver_connection_params"]["proxy_host_info"]["host_url"] + == "proxy.company.com" + ) assert parsed["error_info"]["error_name"] == "ConnectionError" def test_json_serialization_excludes_none_values(self): """Test that JSON serialization properly excludes None values.""" host_info = HostDetails(host_url="workspace.databricks.com", port=443) - + params = DriverConnectionParameters( http_path="/sql/1.0/warehouses/abc123", mode=DatabricksClientType.SEA, host_info=host_info, # All optional fields left as None ) - + json_str = params.to_json() parsed = json.loads(json_str) - + # Required fields present assert parsed["http_path"] == "/sql/1.0/warehouses/abc123" - + # None values should be EXCLUDED from JSON (not included as null) # This is the behavior of JsonSerializableMixin assert "auth_mech" not in parsed @@ -704,11 +723,15 @@ def test_json_serialization_excludes_none_values(self): @patch("databricks.sql.client.Session") -@patch("databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers") +@patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" +) class TestConnectionParameterTelemetry: """Tests for connection parameter population in telemetry.""" - def test_connection_with_proxy_populates_telemetry(self, mock_setup_pools, mock_session): + def test_connection_with_proxy_populates_telemetry( + self, mock_setup_pools, mock_session + ): """Test that proxy configuration is captured in telemetry.""" mock_session_instance = MagicMock() mock_session_instance.guid_hex = "test-session-proxy" @@ -718,8 +741,10 @@ def test_connection_with_proxy_populates_telemetry(self, mock_setup_pools, mock_ mock_session_instance.port = 443 mock_session_instance.host = "workspace.databricks.com" mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log" + ) as mock_export: conn = sql.connect( server_hostname="workspace.databricks.com", http_path="/sql/1.0/warehouses/test", @@ -727,23 +752,25 @@ def test_connection_with_proxy_populates_telemetry(self, mock_setup_pools, mock_ enable_telemetry=True, force_enable_telemetry=True, ) - + # Verify export was called mock_export.assert_called_once() call_args = mock_export.call_args - + # Extract driver_connection_params driver_params = call_args.kwargs.get("driver_connection_params") assert driver_params is not None assert isinstance(driver_params, DriverConnectionParameters) - + # Verify fields are populated assert driver_params.http_path == "/sql/1.0/warehouses/test" assert driver_params.mode == DatabricksClientType.SEA assert driver_params.host_info.host_url == "workspace.databricks.com" assert driver_params.host_info.port == 443 - def test_connection_with_azure_params_populates_telemetry(self, mock_setup_pools, mock_session): + def test_connection_with_azure_params_populates_telemetry( + self, mock_setup_pools, mock_session + ): """Test that Azure-specific parameters are captured in telemetry.""" mock_session_instance = MagicMock() mock_session_instance.guid_hex = "test-session-azure" @@ -753,8 +780,10 @@ def test_connection_with_azure_params_populates_telemetry(self, mock_setup_pools mock_session_instance.port = 443 mock_session_instance.host = "workspace.azuredatabricks.net" mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log" + ) as mock_export: conn = sql.connect( server_hostname="workspace.azuredatabricks.net", http_path="/sql/1.0/warehouses/test", @@ -764,15 +793,20 @@ def test_connection_with_azure_params_populates_telemetry(self, mock_setup_pools enable_telemetry=True, force_enable_telemetry=True, ) - + mock_export.assert_called_once() driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - + # Verify Azure fields - assert driver_params.azure_workspace_resource_id == "/subscriptions/test/resourceGroups/test" + assert ( + driver_params.azure_workspace_resource_id + == "/subscriptions/test/resourceGroups/test" + ) assert driver_params.azure_tenant_id == "tenant-123" - def test_connection_populates_arrow_and_performance_params(self, mock_setup_pools, mock_session): + def test_connection_populates_arrow_and_performance_params( + self, mock_setup_pools, mock_session + ): """Test that Arrow and performance parameters are captured in telemetry.""" mock_session_instance = MagicMock() mock_session_instance.guid_hex = "test-session-perf" @@ -782,15 +816,18 @@ def test_connection_populates_arrow_and_performance_params(self, mock_setup_pool mock_session_instance.port = 443 mock_session_instance.host = "workspace.databricks.com" mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log" + ) as mock_export: # Import pyarrow availability check try: import pyarrow + arrow_available = True except ImportError: arrow_available = False - + conn = sql.connect( server_hostname="workspace.databricks.com", http_path="/sql/1.0/warehouses/test", @@ -799,10 +836,10 @@ def test_connection_populates_arrow_and_performance_params(self, mock_setup_pool enable_telemetry=True, force_enable_telemetry=True, ) - + mock_export.assert_called_once() driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - + # Verify performance fields assert driver_params.enable_arrow == arrow_available assert driver_params.enable_direct_results is True @@ -811,7 +848,9 @@ def test_connection_populates_arrow_and_performance_params(self, mock_setup_pool assert driver_params.async_poll_interval_millis == 2000 assert driver_params.support_many_parameters is True - def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session): + def test_cf_proxy_fields_default_to_false_none( + self, mock_setup_pools, mock_session + ): """Test that CloudFlare proxy fields default to False/None (not yet supported).""" mock_session_instance = MagicMock() mock_session_instance.guid_hex = "test-session-cfproxy" @@ -821,8 +860,10 @@ def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_sess mock_session_instance.port = 443 mock_session_instance.host = "workspace.databricks.com" mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + + with patch( + "databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log" + ) as mock_export: conn = sql.connect( server_hostname="workspace.databricks.com", http_path="/sql/1.0/warehouses/test", @@ -830,7 +871,7 @@ def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_sess enable_telemetry=True, force_enable_telemetry=True, ) - + mock_export.assert_called_once() driver_params = mock_export.call_args.kwargs.get("driver_connection_params") diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py index 6555f1d02..9c0015363 100644 --- a/tests/unit/test_telemetry_push_client.py +++ b/tests/unit/test_telemetry_push_client.py @@ -80,10 +80,13 @@ def test_request_other_error(self): with pytest.raises(ValueError, match="Network error"): self.client.request(HttpMethod.POST, "https://test.com", {}) - @pytest.mark.parametrize("status_code,expected_error", [ - (429, TelemetryRateLimitError), - (503, TelemetryRateLimitError), - ]) + @pytest.mark.parametrize( + "status_code,expected_error", + [ + (429, TelemetryRateLimitError), + (503, TelemetryRateLimitError), + ], + ) def test_request_rate_limit_codes(self, status_code, expected_error): """Test that rate-limit status codes raise TelemetryRateLimitError.""" mock_response = Mock() @@ -97,7 +100,7 @@ def test_request_non_rate_limit_code(self): """Test that non-rate-limit status codes return response.""" mock_response = Mock() mock_response.status = 500 - mock_response.data = b'Server error' + mock_response.data = b"Server error" self.mock_delegate.request.return_value = mock_response response = self.client.request(HttpMethod.POST, "https://test.com", {}) @@ -106,7 +109,9 @@ def test_request_non_rate_limit_code(self): def test_rate_limit_error_logging(self): """Test that rate limit errors are logged with circuit breaker context.""" - with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: mock_response = Mock() mock_response.status = 429 self.mock_delegate.request.return_value = mock_response @@ -121,7 +126,9 @@ def test_rate_limit_error_logging(self): def test_other_error_logging(self): """Test that other errors are logged during wrapping/unwrapping.""" - with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: self.mock_delegate.request.side_effect = ValueError("Network error") with pytest.raises(ValueError, match="Network error"): @@ -137,7 +144,10 @@ def setup_method(self): """Set up test fixtures.""" self.mock_delegate = Mock() self.host = "test-host.example.com" - from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + ) + CircuitBreakerManager._instances.clear() def test_circuit_breaker_opens_after_failures(self): @@ -210,4 +220,5 @@ def test_circuit_breaker_recovers_after_success(self): def test_urllib3_import_fallback(self): """Test that the urllib3 import fallback works correctly.""" from databricks.sql.telemetry.telemetry_push_client import BaseHTTPResponse + assert BaseHTTPResponse is not None diff --git a/tests/unit/test_telemetry_request_error_handling.py b/tests/unit/test_telemetry_request_error_handling.py index aa31f6628..9dd1b5447 100644 --- a/tests/unit/test_telemetry_request_error_handling.py +++ b/tests/unit/test_telemetry_request_error_handling.py @@ -36,18 +36,26 @@ def client(self, mock_delegate, setup_circuit_breaker): return CircuitBreakerTelemetryPushClient(mock_delegate, "test-host.example.com") @pytest.mark.parametrize("status_code", [429, 503]) - def test_request_error_with_rate_limit_codes(self, client, mock_delegate, status_code): + def test_request_error_with_rate_limit_codes( + self, client, mock_delegate, status_code + ): """Test that RequestError with rate-limit codes raises TelemetryRateLimitError.""" - request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + request_error = RequestError( + "HTTP request failed", context={"http-code": status_code} + ) mock_delegate.request.side_effect = request_error with pytest.raises(TelemetryRateLimitError): client.request(HttpMethod.POST, "https://test.com", {}) @pytest.mark.parametrize("status_code", [500, 400, 404]) - def test_request_error_with_non_rate_limit_codes(self, client, mock_delegate, status_code): + def test_request_error_with_non_rate_limit_codes( + self, client, mock_delegate, status_code + ): """Test that RequestError with non-rate-limit codes raises original RequestError.""" - request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + request_error = RequestError( + "HTTP request failed", context={"http-code": status_code} + ) mock_delegate.request.side_effect = request_error with pytest.raises(RequestError, match="HTTP request failed"): diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7254b66cb..fc354752b 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -205,7 +205,6 @@ def test_headers_are_set(self, t_http_client_class): ) def test_proxy_headers_are_set(self): - from databricks.sql.common.http_utils import create_basic_proxy_auth_headers from urllib.parse import urlparse @@ -618,7 +617,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -662,7 +661,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) execute_response, _ = thrift_backend._handle_execute_response( @@ -707,7 +706,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -859,7 +858,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -875,7 +874,6 @@ def test_handle_execute_response_can_handle_without_direct_results( for resp_type in self.execute_response_types: with self.subTest(resp_type=resp_type): - execute_resp = resp_type( status=self.okay_status, directResults=None, @@ -912,7 +910,7 @@ def test_handle_execute_response_can_handle_without_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) ( execute_response, @@ -951,7 +949,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -1750,7 +1748,6 @@ def test_handle_execute_response_sets_active_op_handle(self): def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class ): - import thrift, errno from databricks.sql.thrift_api.TCLIService.TCLIService import Client from databricks.sql.exc import RequestError @@ -1827,7 +1824,6 @@ def test_make_request_will_retry_GetOperationStatus( def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos ): - import urllib3.exceptions mock_gos.side_effect = urllib3.exceptions.HTTPError("Read timed out") @@ -2108,7 +2104,7 @@ def test_retry_args_bounding(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **retry_delay_args, ) retry_delay_expected_vals = { @@ -2361,10 +2357,14 @@ def test_col_to_description(self): test_cases = [ ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"), ("normal_col", {}, "string"), - ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"), + ( + "weird_field", + {b"Spark:DataType:SqlName": b"Some unexpected value"}, + "string", + ), ("missing_field", None, "string"), # None field case ] - + for column_name, field_metadata, expected_type in test_cases: with self.subTest(column_name=column_name, expected_type=expected_type): col = ttypes.TColumnDesc( @@ -2375,7 +2375,9 @@ def test_col_to_description(self): field = ( None if field_metadata is None - else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata) + else pyarrow.field( + column_name, pyarrow.string(), metadata=field_metadata + ) ) result = ThriftDatabricksClient._col_to_description(col, field) @@ -2408,7 +2410,7 @@ def test_hive_schema_to_description(self): [("regular_col", "string")], ), ] - + for columns, arrow_fields, expected_types in test_cases: with self.subTest(arrow_fields=arrow_fields is not None): t_table_schema = ttypes.TTableSchema( diff --git a/tests/unit/test_thrift_field_ids.py b/tests/unit/test_thrift_field_ids.py index a4bba439d..bc018bd05 100644 --- a/tests/unit/test_thrift_field_ids.py +++ b/tests/unit/test_thrift_field_ids.py @@ -36,7 +36,6 @@ def test_all_thrift_field_ids_are_within_allowed_range(self): and hasattr(obj, "thrift_spec") and obj.thrift_spec is not None ): - self._check_class_field_ids(obj, name, violations) if violations: diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py index 4e9ce1bbf..f236c0e98 100644 --- a/tests/unit/test_unified_http_client.py +++ b/tests/unit/test_unified_http_client.py @@ -48,16 +48,19 @@ def http_client(self, client_context): """Create UnifiedHttpClient instance.""" return UnifiedHttpClient(client_context) - @pytest.mark.parametrize("status_code,path", [ - (429, "reason.response"), - (503, "reason.response"), - (500, "direct_response"), - ]) + @pytest.mark.parametrize( + "status_code,path", + [ + (429, "reason.response"), + (503, "reason.response"), + (500, "direct_response"), + ], + ) def test_max_retry_error_with_status_codes(self, http_client, status_code, path): """Test MaxRetryError with various status codes and response paths.""" mock_pool = Mock() max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - + if path == "reason.response": max_retry_error.reason = Mock() max_retry_error.reason.response = Mock() @@ -79,12 +82,21 @@ def test_max_retry_error_with_status_codes(self, http_client, status_code, path) assert "http-code" in error.context assert error.context["http-code"] == status_code - @pytest.mark.parametrize("setup_func", [ - lambda e: None, # No setup - error with no attributes - lambda e: setattr(e, "reason", None), # reason=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr - ]) + @pytest.mark.parametrize( + "setup_func", + [ + lambda e: None, # No setup - error with no attributes + lambda e: setattr(e, "reason", None), # reason=None + lambda e: ( + setattr(e, "reason", Mock()), + setattr(e.reason, "response", None), + ), # reason.response=None + lambda e: ( + setattr(e, "reason", Mock()), + setattr(e.reason, "response", Mock(spec=[])), + ), # No status attr + ], + ) def test_max_retry_error_missing_status(self, http_client, setup_func): """Test MaxRetryError without status code (no crash, empty context).""" mock_pool = Mock() @@ -104,12 +116,12 @@ def test_max_retry_error_prefers_reason_response(self, http_client): """Test that e.reason.response.status is preferred over e.response.status.""" mock_pool = Mock() max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - + # Set both structures with different status codes max_retry_error.reason = Mock() max_retry_error.reason.response = Mock() max_retry_error.reason.response.status = 429 # Should use this - + max_retry_error.response = Mock() max_retry_error.response.status = 500 # Should be ignored diff --git a/tests/unit/test_url_utils.py b/tests/unit/test_url_utils.py index 95f42408d..ff953fb88 100644 --- a/tests/unit/test_url_utils.py +++ b/tests/unit/test_url_utils.py @@ -1,4 +1,5 @@ """Tests for URL utility functions.""" + import pytest from databricks.sql.common.url_utils import normalize_host_with_protocol