ifnt
Execute runtime assertions, indexing checks, and more when JAX code is not traced. Inside traced code (for example, under jax.jit), the checks are skipped.
>>> import ifnt
>>> import jax
>>> from jax import numpy as jnp
>>>
>>> def safe_log(x):
...     ifnt.testing.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(-1)
Traceback (most recent call last):
...
AssertionError: Arrays are not less-ordered
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 1
Max relative difference: 1.
x: array(0)
y: array(-1)
>>> jax.jit(safe_log)(-1)
Array(nan, dtype=float32, weak_type=True)
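The behavior above hinges on telling concrete values apart from traced ones. The following is a minimal sketch of that idea, not ifnt's actual implementation: it assumes that checking `isinstance(x, jax.core.Tracer)` is a sufficient test for "this value is being traced", and the helpers `is_traced` and `assert_all_positive` are hypothetical names introduced only for illustration.

```python
import jax
import numpy as np
from jax import numpy as jnp


def is_traced(*args) -> bool:
    """Hypothetical helper: True if any argument is a JAX tracer (e.g., under jax.jit)."""
    # A tracer is a placeholder for a value that only exists once the compiled
    # function runs, so Python-level checks cannot inspect its contents.
    return any(isinstance(arg, jax.core.Tracer) for arg in args)


def assert_all_positive(x) -> None:
    """Hypothetical assertion that runs only on concrete (non-traced) values."""
    if is_traced(x):
        return  # Skip: the value is unknown at trace time.
    # Concrete values can be pulled into NumPy and checked eagerly.
    np.testing.assert_array_less(0, np.asarray(x))


def safe_log(x):
    assert_all_positive(x)
    return jnp.log(x)


print(safe_log(2.0))  # Eager: the assertion runs and passes.
print(jax.jit(safe_log)(-1.0))  # Traced: the assertion is skipped, log(-1) gives nan.
```

Skipping rather than functionalizing keeps the traced path free of extra outputs, at the cost of no checking inside jit.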
Installation
$ pip install jax-ifnt
Relationship to chex
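DeepMind’s chex provides similar, often complementary, assertions. While chex requires runtime assertions to be “functionalized” with chex.chexify, ifnt skips assertions in traced code. This makes it easy, for example, to verify that indices are not out of bounds when code runs eagerly, something JAX does not check on its own (out-of-bounds retrieval indices are clamped rather than raising an error).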
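To illustrate the indexing use case, here is a minimal sketch of an eager-only bounds check. It is not ifnt's API: `checked_take` is a hypothetical helper, and the sketch assumes, as above, that `isinstance(idx, jax.core.Tracer)` identifies traced indices. When the index is concrete it raises `IndexError`; under `jax.jit` the check is skipped and JAX's usual indexing semantics apply.

```python
import jax
import numpy as np
from jax import numpy as jnp


def checked_take(x, idx):
    """Hypothetical helper: x[idx] with a bounds check along axis 0 when idx is concrete."""
    if not isinstance(idx, jax.core.Tracer):
        idx_np = np.asarray(idx)
        n = x.shape[0]
        # Python-style bounds: negative indices down to -n are allowed.
        if np.any((idx_np < -n) | (idx_np >= n)):
            raise IndexError(f"index {idx_np} is out of bounds for axis 0 with size {n}")
    return x[idx]


x = jnp.arange(3)
print(checked_take(x, 1))  # Eager, in bounds.
print(jax.jit(checked_take)(x, 5))  # Traced: check skipped, JAX does not raise.
try:
    checked_take(x, 5)  # Eager, out of bounds.
except IndexError as err:
    print(err)
```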