From e95466774e1b3cfe4312710624ecfd026d6e3c7a Mon Sep 17 00:00:00 2001 From: Sean Law Date: Wed, 21 Jan 2026 00:24:14 -0500 Subject: [PATCH] Fixed #1116 Avoid Import Fastmath Functions --- pyproject.toml | 2 ++ stumpy/__init__.py | 36 +++++++++++++++++++++++------------- tests/test_fastmath.py | 16 +++++++++++++++- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5fc40c465..2a2d5dfbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,3 +112,5 @@ tbb = ">=2019.5" tests = "./test.sh" coverage = "./test.sh coverage" docs = "cd docs && ./setup.sh" +black = 'black --exclude=".*\.ipynb" --extend-exclude=".venv|.pixi" --diff ./' +isort = 'isort --profile black --skip .venv --skip .pixi ./' diff --git a/stumpy/__init__.py b/stumpy/__init__.py index 870912b15..5df46ff8c 100644 --- a/stumpy/__init__.py +++ b/stumpy/__init__.py @@ -1,9 +1,9 @@ -import importlib +import ast import os.path +import pathlib from importlib.metadata import distribution from site import getsitepackages -import numba from numba import cuda from . import cache, config @@ -38,14 +38,27 @@ # Get the default fastmath flags for all njit functions # and update the _STUMPY_DEFAULTS dictionary -if not numba.config.DISABLE_JIT: # pragma: no cover - njit_funcs = cache.get_njit_funcs() - for module_name, func_name in njit_funcs: - module = importlib.import_module(f".{module_name}", package="stumpy") - func = getattr(module, func_name) - key = module_name + "." + func_name # e.g., core._mass - key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS - config._STUMPY_DEFAULTS[key] = func.targetoptions["fastmath"] + +def _get_fastmath_value(module_name, func_name): # pragma: no cover + fname = module_name + ".py" + fname = pathlib.Path(__file__).parent / fname + with open(fname, "r", encoding="utf-8") as f: + src = f.read() + tree = ast.parse(src) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == func_name: + for dec in node.decorator_list: + for kw in dec.keywords: + if kw.arg == "fastmath": + fastmath_flag = ast.get_source_segment(src, kw.value) + return eval(fastmath_flag) + + +njit_funcs = cache.get_njit_funcs() +for module_name, func_name in njit_funcs: + key = module_name + "." + func_name # e.g., core._mass + key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS + config._STUMPY_DEFAULTS[key] = _get_fastmath_value(module_name, func_name) if cuda.is_available(): from .gpu_aamp import gpu_aamp # noqa: F401 @@ -72,9 +85,6 @@ core._gpu_searchsorted_left = core._gpu_searchsorted_left_driver_not_found core._gpu_searchsorted_right = core._gpu_searchsorted_right_driver_not_found - import ast - import pathlib - # Fix GPU-STUMP Docs gpu_stump.__doc__ = "" filepath = pathlib.Path(__file__).parent / "gpu_stump.py" diff --git a/tests/test_fastmath.py b/tests/test_fastmath.py index a16bb6898..92366069f 100644 --- a/tests/test_fastmath.py +++ b/tests/test_fastmath.py @@ -1,7 +1,10 @@ +import importlib + import numba import numpy as np +import pytest -from stumpy import cache, fastmath +from stumpy import _get_fastmath_value, cache, fastmath def test_set(): @@ -50,3 +53,14 @@ def test_reset(): assert np.isnan(fastmath._add_assoc(0.0, np.inf)) else: # pragma: no cover assert fastmath._add_assoc(0.0, np.inf) == 0.0 + + +@pytest.mark.skipif(numba.config.DISABLE_JIT, reason="JIT Disabled") +def test_get_fastmath_value(): # pragma: no cover + njit_funcs = cache.get_njit_funcs() + for module_name, func_name in njit_funcs: + module = importlib.import_module(f".{module_name}", package="stumpy") + func = getattr(module, func_name) + ref = func.targetoptions["fastmath"] + cmp = _get_fastmath_value(module_name, func_name) + assert ref == cmp