ifnt


Execute runtime assertions, indexing checks, and more if jax code is not traced.

>>> import ifnt
>>> import jax
>>> from jax import numpy as jnp
>>>
>>> def safe_log(x):
...     ifnt.testing.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(-1)
Traceback (most recent call last):
...
AssertionError: Arrays are not less-ordered

Mismatched elements: 1 / 1 (100%)
Max absolute difference: 1
Max relative difference: 1.
x: array(0)
y: array(-1)
>>> jax.jit(safe_log)(-1)
Array(nan, dtype=float32, weak_type=True)
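The behavior above can be approximated with a check for JAX tracers. This is a minimal sketch, not ifnt's implementation. It assumes that `isinstance(value, jax.core.Tracer)` is a sufficient test for "being traced", and the helpers `is_traced` and `assert_less_if_concrete` are hypothetical names used only for illustration.

```python
import jax
import numpy as np
from jax import numpy as jnp


def is_traced(*args) -> bool:
    """Hypothetical helper: True if any argument is a JAX tracer (e.g., inside `jax.jit`)."""
    return any(isinstance(arg, jax.core.Tracer) for arg in args)


def assert_less_if_concrete(x, y) -> None:
    """Hypothetical helper: check x < y element-wise, but only for concrete values."""
    if is_traced(x, y):
        # Values are unknown at trace time, so the check is skipped entirely.
        return
    np.testing.assert_array_less(np.asarray(x), np.asarray(y))


def safe_log(x):
    assert_less_if_concrete(0, x)
    return jnp.log(x)


# Eager call with a concrete value: the assertion runs and fails.
try:
    safe_log(-1)
except AssertionError:
    print("assertion raised")

# Traced call: the assertion is skipped and log(-1) evaluates to nan.
print(jax.jit(safe_log)(-1))
```

Skipping the check outright keeps traced code free of host callbacks, at the cost of no validation inside `jax.jit`. The same tracer test also triggers under other transformations such as `jax.vmap` or `jax.grad`.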

Installation

$ pip install jax-ifnt

Relationship to chex

DeepMind’s chex provides similar, often complementary, assertions. While chex requires runtime assertions to be “functionalized” with chex.chexify, ifnt simply skips assertions in traced code. This makes it straightforward, for example, to verify that indices are not out of bounds whenever the values are concrete.
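To illustrate the index-checking case, here is a minimal sketch of a bounds check that runs only when the index is concrete. The helper `checked_getitem` is hypothetical and is not part of ifnt's API. It assumes integer indexing along the first axis and uses `jax.core.Tracer` to detect tracing.

```python
import jax
import numpy as np
from jax import numpy as jnp


def checked_getitem(x, index):
    """Hypothetical helper: index `x` along axis 0, raising IndexError for
    out-of-bounds integer indices when `index` is concrete. Under tracing the
    check is skipped."""
    if not isinstance(index, jax.core.Tracer):
        index_np = np.asarray(index)
        n = x.shape[0]
        # Valid indices along axis 0 lie in [-n, n).
        if np.any((index_np < -n) | (index_np >= n)):
            raise IndexError(f"index {index_np} is out of bounds for axis 0 with size {n}")
    return x[index]


x = jnp.arange(3)
print(checked_getitem(x, 1))

try:
    checked_getitem(x, 5)
except IndexError as err:
    print(err)

# Inside jit, `index` is a tracer, so no exception is raised here; JAX applies
# its own out-of-bounds indexing semantics instead of raising.
print(jax.jit(checked_getitem)(x, 5))
```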