ifnt.util
ifnt.util provides basic functionality for constructing functions that are skipped if their arguments are traced, conditioning on whether values are traced, safe indexing, and printing non-traced tensors.
- ifnt.util.broadcast_over_dict(func: F) → F
Broadcast a function over values of a dictionary.
- Parameters:
func – Function to broadcast.
Example
>>> from functools import singledispatch
>>> @ifnt.util.broadcast_over_dict
... @singledispatch
... def add_one(x):
...     return x + 1
>>> add_one({"a": 1, "b": 2})
{'a': 2, 'b': 3}
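The example relies on `functools.singledispatch`. A minimal sketch of how such a decorator could work, assuming `func` is a singledispatch function as in the example, is to register a `dict` overload that applies `func` to every value. This is an illustration rather than ifnt's implementation, and `broadcast_over_dict_sketch` is a hypothetical name.

```python
from functools import singledispatch


def broadcast_over_dict_sketch(func):
    """Hypothetical sketch: register a `dict` overload on a singledispatch function."""

    @func.register(dict)
    def _(x, *args, **kwargs):
        # Call the dispatching function again on each value, so nested dicts are handled too.
        return {key: func(value, *args, **kwargs) for key, value in x.items()}

    return func


@broadcast_over_dict_sketch
@singledispatch
def add_one(x):
    return x + 1


print(add_one({"a": 1, "b": {"c": 2}}))  # {'a': 2, 'b': {'c': 3}}
```

Because the overload re-enters the dispatching function, nested dictionaries are also handled in this sketch; whether ifnt does the same is not stated above.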
- ifnt.util.disable(do_disable: bool = True)
Disable all ifnt behavior, even if values are not traced, so that is_traced() returns True.
- Parameters:
do_disable – Disable ifnt if truthy.
Example
>>> ifnt.testing.assert_allclose(1, 2)
Traceback (most recent call last):
...
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: 1
Max relative difference among violations: 0.5
ACTUAL: array(1)
DESIRED: array(2)
>>> with ifnt.disable():
...     ifnt.testing.assert_allclose(1, 2)
>>> ifnt.testing.assert_allclose(1, 2)
Traceback (most recent call last):
...
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: 1
Max relative difference among violations: 0.5
ACTUAL: array(1)
DESIRED: array(2)
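One way to model this switch, sketched under the assumption that it combines a context-local flag with the `IFNT_DISABLED` environment variable mentioned under `is_traced()`, is a `contextvars.ContextVar` that the context manager sets and then restores. The names `_DISABLED`, `is_disabled`, and `disable_sketch` are hypothetical and not part of ifnt.

```python
import contextlib
import contextvars
import os
from typing import Iterator

# Hypothetical flag; ifnt's internal state may be organized differently.
_DISABLED = contextvars.ContextVar("ifnt_disabled", default=False)


def is_disabled() -> bool:
    # Disabled if the environment variable is set (to any value) or inside a disabling block.
    return "IFNT_DISABLED" in os.environ or _DISABLED.get()


@contextlib.contextmanager
def disable_sketch(do_disable: bool = True) -> Iterator[None]:
    # Set the flag for the duration of the block and restore the previous value afterwards.
    token = _DISABLED.set(bool(do_disable))
    try:
        yield
    finally:
        _DISABLED.reset(token)
```

A context variable, rather than a plain module-level global, keeps the flag local to the current thread or asyncio task, and resetting the token restores whatever value was active before the block.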
- class ifnt.util.index_guard(x)
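Safe indexing that checks for out-of-bounds indices when values are not traced. Without the guard, JAX clamps out-of-bounds indices when retrieving elements instead of raising an error.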
- Parameters:
x – Array to guard.
Example
>>> x = jnp.arange(3)
>>> x[7]
Array(2, dtype=int32)
>>> ifnt.index_guard(x)[7]
Traceback (most recent call last):
...
IndexError: index 7 is out of bounds for axis 0 with size 3
>>> ifnt.index_guard(x).at[1].set(7)
Array([0, 7, 2], dtype=int32)
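A minimal sketch of the guard idea, assuming a single integer index along the first axis: check the bounds eagerly when neither the array nor the index is a `jax.core.Tracer`, and otherwise defer to JAX's usual indexing. `IndexGuardSketch` is a hypothetical name, and it omits the `.at` support shown in the example.

```python
import jax
import jax.numpy as jnp


class IndexGuardSketch:
    """Hypothetical minimal guard for a single integer index along axis 0."""

    def __init__(self, x):
        self.x = x

    def __getitem__(self, index):
        # Following the documented behavior, only check when neither the array nor the index is traced.
        if not any(isinstance(v, jax.core.Tracer) for v in (self.x, index)):
            size = self.x.shape[0]
            if not -size <= int(index) < size:
                raise IndexError(f"index {int(index)} is out of bounds for axis 0 with size {size}")
        return self.x[index]


x = jnp.arange(3)
value = IndexGuardSketch(x)[1]  # in bounds, so the element is returned
```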
- ifnt.util.is_traced(*xs: Any) → bool
Return whether any of the arguments are traced.
Warning
is_traced() always returns True if the IFNT_DISABLED environment variable is set or the function is called within a disable() context. Consequently, any code that is skipped when values are traced is also skipped in these cases.
- Parameters:
xs – Value or values to check.
- Returns:
Whether any of the values are traced.
Example
>>> def f(x):
...     return ifnt.is_traced(x)
>>> x = jnp.zeros(3)
>>> f(x)
False
>>> jax.jit(f)(x)
Array(True, dtype=bool)
>>> with ifnt.disable():
...     f(x)
True
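The core check can be sketched as an `isinstance` test against `jax.core.Tracer`, the base class of the placeholder values JAX passes through a function while tracing it under `jax.jit`. This is an assumption about the approach rather than ifnt's source; the hypothetical `is_traced_sketch` deliberately leaves out the `IFNT_DISABLED` and `disable()` overrides described in the warning.

```python
from typing import Any

import jax
import jax.numpy as jnp


def is_traced_sketch(*xs: Any) -> bool:
    # Hypothetical minimal check: True if any argument is a JAX tracer.
    # Unlike ifnt.is_traced, this ignores IFNT_DISABLED and disable().
    return any(isinstance(x, jax.core.Tracer) for x in xs)


x = jnp.zeros(3)
eager = is_traced_sketch(x)  # x is a concrete array here, so this is False
traced = jax.jit(is_traced_sketch)(x)  # x is traced inside jit, so this holds True
```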
- ifnt.util.print(*objects, sep=' ', end='\n', file=None, flush=False) → None
If none of the arguments are traced, print objects separated by sep and followed by end to stdout or a file.
- Parameters:
*objects – Objects to print.
sep – Separator between objects.
end – String appended after the last object.
file – Stream with a write method to print objects to.
flush – Forcibly flush the stream.
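Because the signature mirrors the built-in `print`, a plausible sketch is a thin wrapper that returns early if any object is a `jax.core.Tracer` and otherwise forwards every argument to `builtins.print`. The name `print_sketch` is hypothetical and this is not ifnt's implementation.

```python
import builtins
import io
from typing import Any, Optional, TextIO

import jax
import jax.numpy as jnp


def print_sketch(*objects: Any, sep: str = " ", end: str = "\n",
                 file: Optional[TextIO] = None, flush: bool = False) -> None:
    # Hypothetical stand-in: do nothing if any object is a JAX tracer.
    if any(isinstance(obj, jax.core.Tracer) for obj in objects):
        return
    # Otherwise defer to the built-in print with the same arguments.
    builtins.print(*objects, sep=sep, end=end, file=file, flush=flush)


buffer = io.StringIO()
print_sketch("a", 1, sep="-", file=buffer)  # writes "a-1\n" to the buffer
```

To print values that are traced, JAX provides `jax.debug.print`, which runs when the compiled function executes.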
- ifnt.util.raise_if_traced(func: F) → F
Raise a RuntimeError if any of the function's arguments are traced.
- Parameters:
func – Function that should fail if any of its arguments are traced.
Example
>>> @ifnt.raise_if_traced
... def multiply(x):
...     return 2 * x
>>>
>>> multiply(jnp.arange(3))
Array([0, 2, 4], dtype=int32)
>>> jax.jit(multiply)(jnp.arange(3))
Traceback (most recent call last):
...
RuntimeError: Cannot execute `multiply` because one or more of its arguments are traced.
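A minimal sketch of such a decorator, assuming tracing is detected by scanning the pytree leaves of all positional and keyword arguments for `jax.core.Tracer` instances. `raise_if_traced_sketch` is a hypothetical name and its error message simply copies the one in the example.

```python
import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp


def raise_if_traced_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Flatten positional and keyword arguments so tracers nested in containers are found.
        leaves = jax.tree_util.tree_leaves((args, kwargs))
        if any(isinstance(leaf, jax.core.Tracer) for leaf in leaves):
            raise RuntimeError(
                f"Cannot execute `{func.__name__}` because one or more of its arguments are traced."
            )
        return func(*args, **kwargs)

    return wrapper


@raise_if_traced_sketch
def multiply(x):
    return 2 * x


doubled = multiply(jnp.arange(3))  # concrete input, so the call goes through
```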
- ifnt.util.skip_if_traced(func: F) → F
Skip a function if any of its arguments are traced. The decorated function does not return a value, even if the original function did.
- Parameters:
func – Function to skip if any of its arguments are traced.
Example
>>> @ifnt.skip_if_traced
... def assert_positive(x):
...     assert x.min() > 0
>>>
>>> assert_positive(-jnp.zeros(5))
Traceback (most recent call last):
...
AssertionError
>>> jax.jit(assert_positive)(-jnp.zeros(5))
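The skipping behavior can be sketched as a wrapper that calls the original function only when no argument leaf is a `jax.core.Tracer` and always returns `None`, matching the note that the decorated function does not return a value. `skip_if_traced_sketch` is a hypothetical name, not ifnt's implementation.

```python
import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp


def skip_if_traced_sketch(func: Callable[..., Any]) -> Callable[..., None]:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        leaves = jax.tree_util.tree_leaves((args, kwargs))
        # Run the original function only for concrete inputs; skip it entirely while tracing.
        if not any(isinstance(leaf, jax.core.Tracer) for leaf in leaves):
            func(*args, **kwargs)
        # Falling through returns None, discarding any value the original function returned.

    return wrapper


@skip_if_traced_sketch
def assert_positive(x):
    assert x.min() > 0


assert_positive(jnp.ones(5))  # concrete and positive, so the check runs and passes
```

This pattern suits runtime checks such as assertions: they run on concrete inputs during eager debugging and are absent from compiled code.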