mirror of
https://github.com/PiBrewing/craftbeerpi4.git
synced 2024-12-01 19:24:21 +01:00
90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
import pandas.util._test_decorators as td
|
||
|
|
||
|
from pandas import Series, option_context
|
||
|
import pandas._testing as tm
|
||
|
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
@pytest.mark.filterwarnings("ignore:\\nThe keyword argument")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
class TestApply:
    """Compare ``Rolling.apply`` with ``engine="numba"`` against ``engine="cython"``.

    The ``nogil``, ``parallel``, ``nopython`` and ``center`` arguments are pytest
    fixtures defined outside this module (presumably in a conftest.py — not
    visible here).
    """

    @pytest.mark.parametrize("jit", [True, False])
    def test_numba_vs_cython(self, jit, nogil, parallel, nopython, center):
        """The numba and cython engines agree when extra positional ``args`` are passed."""

        def f(x, *args):
            arg_sum = 0
            for arg in args:
                arg_sum += arg
            return np.mean(x) + arg_sum

        if jit:
            import numba

            # Exercise both a plain Python callable and a pre-jitted one.
            f = numba.jit(f)

        engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
        args = (2,)

        s = Series(range(10))
        result = s.rolling(2, center=center).apply(
            f, args=args, engine="numba", engine_kwargs=engine_kwargs, raw=True
        )
        expected = s.rolling(2, center=center).apply(
            f, engine="cython", args=args, raw=True
        )
        tm.assert_series_equal(result, expected)

    @pytest.mark.parametrize("jit", [True, False])
    def test_cache(self, jit, nogil, parallel, nopython):
        """Numba-compiled functions are cached per function and survive a switch.

        Runs ``func_1``, then ``func_2``, then ``func_1`` again. Each run must
        match the cython engine. Each function must have its own entry in
        ``NUMBA_FUNC_CACHE``.
        """
        # Test that the functions are cached correctly if we switch functions
        def func_1(x):
            return np.mean(x) + 4

        def func_2(x):
            return np.std(x) * 5

        if jit:
            import numba

            func_1 = numba.jit(func_1)
            func_2 = numba.jit(func_2)

        engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}

        roll = Series(range(10)).rolling(2)

        def check_matches_cython(func):
            # The numba result for `func` must equal the cython reference result.
            result = roll.apply(
                func, engine="numba", engine_kwargs=engine_kwargs, raw=True
            )
            expected = roll.apply(func, engine="cython", raw=True)
            tm.assert_series_equal(result, expected)

        check_matches_cython(func_1)
        # func_1 should be in the cache now
        assert (func_1, "rolling_apply") in NUMBA_FUNC_CACHE

        check_matches_cython(func_2)
        # Switching functions must add an entry for func_2 without evicting func_1.
        assert (func_2, "rolling_apply") in NUMBA_FUNC_CACHE
        assert (func_1, "rolling_apply") in NUMBA_FUNC_CACHE

        # This run should use the cached func_1
        check_matches_cython(func_1)
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
def test_use_global_config():
    """``engine=None`` under ``compute.use_numba=True`` matches an explicit ``engine="numba"``."""

    def mean_plus_two(window):
        return np.mean(window) + 2

    roller = Series(range(10)).rolling(2)

    with option_context("compute.use_numba", True):
        # No engine given: the test expects the global option to select numba.
        via_option = roller.apply(mean_plus_two, engine=None, raw=True)

    via_explicit_engine = roller.apply(mean_plus_two, engine="numba", raw=True)
    tm.assert_series_equal(via_explicit_engine, via_option)
|