⚙️ ifnt.util

ifnt.util provides basic functionality for constructing functions that are skipped when traced, conditioning on whether values are traced, safe indexing, and printing non-traced tensors.

ifnt.util.broadcast_over_dict(func: F) → F

Broadcast a function over values of a dictionary.

Parameters:

func – Function to broadcast.

Example

>>> from functools import singledispatch
>>> @ifnt.util.broadcast_over_dict
... @singledispatch
... def add_one(x):
...     return x + 1
>>> add_one({"a": 1, "b": 2})
{'a': 2, 'b': 3}
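The example above relies on the decorated function being a functools.singledispatch function. As a rough illustration of the idea (not necessarily how ifnt implements it), a decorator could register a dict handler that applies the function to every value. The name broadcast_over_dict_sketch is hypothetical and only uses the standard library.

```python
from functools import singledispatch


def broadcast_over_dict_sketch(func):
    """Hypothetical stand-in: register a dict handler on a singledispatch function."""

    def _apply_to_values(mapping, *args, **kwargs):
        # Apply the dispatching function to each value and keep the keys.
        return {key: func(value, *args, **kwargs) for key, value in mapping.items()}

    func.register(dict, _apply_to_values)
    return func


@broadcast_over_dict_sketch
@singledispatch
def add_one(x):
    return x + 1


result = add_one({"a": 1, "b": 2})
# Nested dictionaries dispatch to the dict handler again.
nested = add_one({"a": {"b": 1}})
```

Because the handler calls the dispatching function rather than the original implementation, nested dictionaries are handled recursively in this sketch.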
ifnt.util.disable(do_disable: bool = True)

Disable all ifnt behavior even if values are not traced.

Parameters:

do_disable – Disable ifnt if truthy.

Example

>>> ifnt.testing.assert_allclose(1, 2)
Traceback (most recent call last):
...
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: 1
Max relative difference among violations: 0.5
ACTUAL: array(1)
DESIRED: array(2)
>>> with ifnt.disable():
...     ifnt.testing.assert_allclose(1, 2)
>>> ifnt.testing.assert_allclose(1, 2)
Traceback (most recent call last):
...
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: 1
Max relative difference among violations: 0.5
ACTUAL: array(1)
DESIRED: array(2)
class ifnt.util.index_guard(x)

Safe indexing that checks for out-of-bounds indices when not traced.

Parameters:

x – Array to guard.

Example

>>> x = jnp.arange(3)
>>> x[7]
Array(2, dtype=int32)
>>> ifnt.index_guard(x)[7]
Traceback (most recent call last):
...
IndexError: index 7 is out of bounds for axis 0 with size 3
>>> ifnt.index_guard(x).at[1].set(7)
Array([0, 7, 2], dtype=int32)
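To make the idea of a bounds check that only runs on concrete values more tangible, here is a minimal sketch for integer indexing along the first axis. The helper guarded_getitem is hypothetical and is not part of ifnt; it assumes that traced values are instances of jax.core.Tracer.

```python
import jax
import jax.numpy as jnp
from jax.core import Tracer


def guarded_getitem(x, index):
    """Hypothetical helper: integer indexing along axis 0 with a bounds check."""
    # Only check bounds when both the array and the index are concrete.
    if not isinstance(x, Tracer) and not isinstance(index, Tracer):
        size = x.shape[0]
        if not -size <= index < size:
            raise IndexError(
                f"index {index} is out of bounds for axis 0 with size {size}"
            )
    return x[index]


x = jnp.arange(3)
in_bounds = guarded_getitem(x, 1)

try:
    guarded_getitem(x, 7)
    raised = False
except IndexError:
    raised = True

# Under jit the array is a tracer, so the check is skipped and no error is raised.
jitted = jax.jit(lambda a: guarded_getitem(a, 7))(x)
```

The check is a Python-level branch, so it costs nothing inside compiled code, but it also cannot catch out-of-bounds indices there.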
ifnt.util.is_traced(*xs: Any) → bool

Return whether any of the arguments are traced.

Warning

is_traced() always returns True if the IFNT_DISABLED environment variable is set or if it is called within a disable() context. Consequently, any code that is skipped when values are traced is also skipped when ifnt is disabled.

Parameters:

xs – Value or values to check.

Returns:

True if any of the values are traced.

Example

>>> def f(x):
...     return ifnt.is_traced(x)
>>> x = jnp.zeros(3)
>>> f(x)
False
>>> jax.jit(f)(x)
Array(True, dtype=bool)
>>> with ifnt.disable():
...     f(x)
True
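For intuition, a trace check can be approximated by testing whether a value is an instance of jax.core.Tracer. The function is_traced_sketch below is a hypothetical simplification: it ignores the IFNT_DISABLED environment variable and the disable() context, and it does not inspect containers.

```python
import jax
import jax.numpy as jnp
from jax.core import Tracer


def is_traced_sketch(*xs):
    """Hypothetical stand-in: True if any argument is a JAX tracer."""
    return any(isinstance(x, Tracer) for x in xs)


def f(x):
    return is_traced_sketch(x)


x = jnp.zeros(3)
eager = f(x)                  # called on a concrete array
jitted = bool(jax.jit(f)(x))  # called on a tracer while jit traces f
```

Under jit the Python bool returned by f is converted to an array, which is why the result is wrapped in bool() here.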
ifnt.util.print(*objects, sep=' ', end='\n', file=None, flush=False) → None

If none of the arguments are traced, print objects separated by sep followed by end to stdout or a file.

Parameters:
  • *objects – Objects to print.

  • sep – Separator between objects.

  • end – End string.

  • file – Stream with write method to print objects to.

  • flush – Forcibly flush the stream.
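The behavior described above can be sketched with a small wrapper around the built-in print. The name print_if_concrete is hypothetical, and the tracer check via jax.core.Tracer is an assumption about how such a function might be written rather than ifnt's implementation. Output is captured in a StringIO stream so that the eager and jitted calls can be compared.

```python
import io

import jax
import jax.numpy as jnp
from jax.core import Tracer


def print_if_concrete(*objects, **kwargs):
    """Hypothetical stand-in: print only if no argument is a JAX tracer."""
    if any(isinstance(obj, Tracer) for obj in objects):
        return
    print(*objects, **kwargs)


stream = io.StringIO()


def double(x):
    print_if_concrete("sum:", x.sum(), file=stream)
    return 2 * x


doubled = double(jnp.arange(3))   # concrete values, so a line is written
eager_output = stream.getvalue()

jax.jit(double)(jnp.arange(3))    # traced values, so nothing is written
jit_output = stream.getvalue()
```

A helper like this lets debugging output stay in a function that is later wrapped in jax.jit without raising tracer-related errors.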

ifnt.util.raise_if_traced(func: F) → F

Raise an error if any of the function's arguments are traced.

Parameters:

func – Function to fail if any of its arguments are traced.

Example

>>> @ifnt.raise_if_traced
... def multiply(x):
...     return 2 * x
>>>
>>> multiply(jnp.arange(3))
Array([0, 2, 4], dtype=int32)
>>> jax.jit(multiply)(jnp.arange(3))
Traceback (most recent call last):
...
RuntimeError: Cannot execute `multiply` because one or more of its arguments are
traced.
ifnt.util.skip_if_traced(func: F) → F

Skip a function if any of its arguments are traced. The decorated function does not return a value, even if the original function did.

Parameters:

func – Function to skip if any of its arguments are traced.

Example

>>> @ifnt.skip_if_traced
... def assert_positive(x):
...     assert x.min() > 0
>>>
>>> assert_positive(-jnp.zeros(5))
Traceback (most recent call last):
...
AssertionError
>>> jax.jit(assert_positive)(-jnp.zeros(5))
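A decorator with this behavior could be sketched as follows. The name skip_if_traced_sketch is hypothetical, and the check only looks at top-level positional and keyword arguments, which may differ from what ifnt does. The sketch also shows why the decorated function never returns a value: the wrapper discards the result so that eager and traced calls behave the same way.

```python
import functools

import jax
import jax.numpy as jnp
from jax.core import Tracer


def skip_if_traced_sketch(func):
    """Hypothetical stand-in: call func only if no argument is a JAX tracer."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        values = list(args) + list(kwargs.values())
        if any(isinstance(value, Tracer) for value in values):
            return None
        # Discard the result so eager and traced calls both return None.
        func(*args, **kwargs)
        return None

    return wrapper


@skip_if_traced_sketch
def assert_positive(x):
    assert x.min() > 0
    return "checked"


def scale(x):
    assert_positive(x)
    return 2 * x


eager_result = assert_positive(jnp.ones(3))  # check runs, result is discarded
jitted = jax.jit(scale)(jnp.zeros(3))        # check is skipped while tracing
```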