Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,5 @@ tbb = ">=2019.5"
tests = "./test.sh"
coverage = "./test.sh coverage"
docs = "cd docs && ./setup.sh"
black = 'black --exclude=".*\.ipynb" --extend-exclude=".venv|.pixi" --diff ./'
isort = 'isort --profile black --skip .venv --skip .pixi ./'
36 changes: 23 additions & 13 deletions stumpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import importlib
import ast
import os.path
import pathlib
from importlib.metadata import distribution
from site import getsitepackages

import numba
from numba import cuda

from . import cache, config
Expand Down Expand Up @@ -38,14 +38,27 @@
# Get the default fastmath flags for all njit functions
# and update the _STUMPY_DEFAULTS dictionary

if not numba.config.DISABLE_JIT: # pragma: no cover
njit_funcs = cache.get_njit_funcs()
for module_name, func_name in njit_funcs:
module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
key = module_name + "." + func_name # e.g., core._mass
key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS
config._STUMPY_DEFAULTS[key] = func.targetoptions["fastmath"]

def _get_fastmath_value(module_name, func_name): # pragma: no cover
fname = module_name + ".py"
fname = pathlib.Path(__file__).parent / fname
with open(fname, "r", encoding="utf-8") as f:
src = f.read()
tree = ast.parse(src)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == func_name:
for dec in node.decorator_list:
for kw in dec.keywords:
if kw.arg == "fastmath":
fastmath_flag = ast.get_source_segment(src, kw.value)
return eval(fastmath_flag)


njit_funcs = cache.get_njit_funcs()
for module_name, func_name in njit_funcs:
key = module_name + "." + func_name # e.g., core._mass
key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS
config._STUMPY_DEFAULTS[key] = _get_fastmath_value(module_name, func_name)

if cuda.is_available():
from .gpu_aamp import gpu_aamp # noqa: F401
Expand All @@ -72,9 +85,6 @@
core._gpu_searchsorted_left = core._gpu_searchsorted_left_driver_not_found
core._gpu_searchsorted_right = core._gpu_searchsorted_right_driver_not_found

import ast
import pathlib

# Fix GPU-STUMP Docs
gpu_stump.__doc__ = ""
filepath = pathlib.Path(__file__).parent / "gpu_stump.py"
Expand Down
16 changes: 15 additions & 1 deletion tests/test_fastmath.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import importlib

import numba
import numpy as np
import pytest

from stumpy import cache, fastmath
from stumpy import _get_fastmath_value, cache, fastmath


def test_set():
Expand Down Expand Up @@ -50,3 +53,14 @@ def test_reset():
assert np.isnan(fastmath._add_assoc(0.0, np.inf))
else: # pragma: no cover
assert fastmath._add_assoc(0.0, np.inf) == 0.0


@pytest.mark.skipif(numba.config.DISABLE_JIT, reason="JIT Disabled")
def test_get_fastmath_value(): # pragma: no cover
njit_funcs = cache.get_njit_funcs()
for module_name, func_name in njit_funcs:
module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
ref = func.targetoptions["fastmath"]
cmp = _get_fastmath_value(module_name, func_name)
assert ref == cmp