MAINT: Check essential data functions by mtsokol · Pull Request #380 · data-apis/array-api-tests · GitHub
Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions array_api_tests/test_creation_functions.py
1 change: 1 addition & 0 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def test_repeat(x, kw, data):

reshape_shape = st.shared(hh.shapes(), key="reshape_shape")

@pytest.mark.has_setup_funcs
@pytest.mark.unvectorized
@given(
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
Expand Down
29 changes: 29 additions & 0 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def test_acosh(x):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
@given(data=st.data())
def test_add(ctx, data):
Expand Down Expand Up @@ -854,6 +855,7 @@ def test_atanh(x):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
)
Expand All @@ -873,6 +875,7 @@ def test_bitwise_and(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
)
Expand All @@ -895,6 +898,7 @@ def test_bitwise_left_shift(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes)
)
Expand All @@ -913,6 +917,7 @@ def test_bitwise_invert(ctx, data):
unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}")


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
)
Expand All @@ -932,6 +937,7 @@ def test_bitwise_or(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
)
Expand All @@ -953,6 +959,7 @@ def test_bitwise_right_shift(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
)
Expand Down Expand Up @@ -981,6 +988,7 @@ def test_ceil(x):


@pytest.mark.min_version("2023.12")
@pytest.mark.has_setup_funcs
@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
def test_clip(x, data):
# Ensure that if both min and max are arrays that all three of x, min, max
Expand Down Expand Up @@ -1145,6 +1153,7 @@ def test_cosh(x):
unary_assert_against_refimpl("cosh", x, out, refimpl)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes))
@given(data=st.data())
def test_divide(ctx, data):
Expand All @@ -1168,6 +1177,7 @@ def test_divide(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
@given(data=st.data())
def test_equal(ctx, data):
Expand Down Expand Up @@ -1242,6 +1252,7 @@ def refimpl(z):
unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes))
@given(data=st.data())
def test_floor_divide(ctx, data):
Expand All @@ -1261,6 +1272,7 @@ def test_floor_divide(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes))
@given(data=st.data())
def test_greater(ctx, data):
Expand All @@ -1281,6 +1293,7 @@ def test_greater(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes))
@given(data=st.data())
def test_greater_equal(ctx, data):
Expand Down Expand Up @@ -1352,6 +1365,7 @@ def test_isnan(x):
unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes))
@given(data=st.data())
def test_less(ctx, data):
Expand All @@ -1372,6 +1386,7 @@ def test_less(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes))
@given(data=st.data())
def test_less_equal(ctx, data):
Expand Down Expand Up @@ -1463,6 +1478,7 @@ def logaddexp_refimpl(l: float, r: float) -> float:


@pytest.mark.min_version("2023.12")
@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_logaddexp(x1, x2):
out = xp.logaddexp(x1, x2)
Expand All @@ -1476,6 +1492,7 @@ def test_logaddexp(x1, x2):
)


@pytest.mark.has_setup_funcs
@given(hh.arrays(dtype=xp.bool, shape=hh.shapes()))
def test_logical_not(x):
out = xp.logical_not(x)
Expand All @@ -1486,6 +1503,7 @@ def test_logical_not(x):
)


@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_and(x1, x2):
out = xp.logical_and(x1, x2)
Expand All @@ -1500,6 +1518,7 @@ def test_logical_and(x1, x2):
)


@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_or(x1, x2):
out = xp.logical_or(x1, x2)
Expand All @@ -1514,6 +1533,7 @@ def test_logical_or(x1, x2):
)


@pytest.mark.has_setup_funcs
@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_xor(x1, x2):
out = xp.logical_xor(x1, x2)
Expand Down Expand Up @@ -1546,6 +1566,7 @@ def test_minimum(x1, x2):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@given(data=st.data())
def test_multiply(ctx, data):
Expand Down Expand Up @@ -1577,6 +1598,7 @@ def test_negative(ctx, data):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
@given(data=st.data())
def test_not_equal(ctx, data):
Expand All @@ -1598,6 +1620,7 @@ def test_not_equal(ctx, data):


@pytest.mark.min_version("2024.12")
@pytest.mark.has_setup_funcs
@given(
shapes=hh.two_mutually_broadcastable_shapes,
dtype=hh.real_floating_dtypes,
Expand All @@ -1617,6 +1640,8 @@ def test_nextafter(shapes, dtype, data):
out=out
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes))
@given(data=st.data())
def test_positive(ctx, data):
Expand All @@ -1629,6 +1654,7 @@ def test_positive(ctx, data):
ph.assert_array_elements(ctx.func_name, out=out, expected=x)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
@given(data=st.data())
def test_pow(ctx, data):
Expand Down Expand Up @@ -1676,6 +1702,7 @@ def test_reciprocal(x):


@pytest.mark.skip(reason="flaky")
@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))
@given(data=st.data())
def test_remainder(ctx, data):
Expand Down Expand Up @@ -1770,6 +1797,7 @@ def test_sqrt(x):
)


@pytest.mark.has_setup_funcs
@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
@given(data=st.data())
def test_subtract(ctx, data):
Expand Down Expand Up @@ -1923,6 +1951,7 @@ def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):


@pytest.mark.min_version("2024.12")
@pytest.mark.has_setup_funcs
@pytest.mark.unvectorized
@given(
x1x2=hh.array_and_py_scalar([xp.int32]),
Expand Down
1 change: 1 addition & 0 deletions array_api_tests/test_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_any(x, data):

@pytest.mark.unvectorized
@pytest.mark.min_version("2024.12")
@pytest.mark.has_setup_funcs
@given(
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
data=st.data(),
Expand Down
15 changes: 15 additions & 0 deletions conftest.py