🎲 ifnt.random

ifnt.random facilitates stateful random number generation to avoid repeated calls to jax.random.split().
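For context, the snippet below contrasts the explicit key threading that plain jax.random requires with a minimal stateful wrapper. `SimpleRandomState` is a hypothetical sketch written for illustration, not the ifnt implementation; it relies only on the public `jax.random.key`, `jax.random.split`, and `jax.random.normal` functions.

```python
import jax


class SimpleRandomState:
    """Hypothetical minimal stateful wrapper (not ifnt's actual code)."""

    def __init__(self, seed: int) -> None:
        self._key = jax.random.key(seed)

    def get_key(self) -> jax.Array:
        # Split the stored key: keep one half as the new state, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey

    def normal(self, shape=()) -> jax.Array:
        return jax.random.normal(self.get_key(), shape)


# Manual bookkeeping: the caller must split before every draw.
key = jax.random.key(7)
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey)
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey)

# Stateful wrapper: the same splitting happens inside get_key().
rng = SimpleRandomState(7)
c = rng.normal()
d = rng.normal()
```

Both paths consume keys in the same order, so `a == c` and `b == d`. Because `get_key` updates `self._key` as a Python side effect, this style of state does not compose with `jax.jit` tracing.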

class ifnt.random.JaxRandomState(seed: int | None = None)

Utility class for sampling random variables using the JAX interface with automatic random state handling.

Parameters:

seed – Initial random number generator seed or None to use a time-based seed.

Warning

This implementation is stateful and does not support jax.jit() compilation.

Example

>>> rng = ifnt.random.JaxRandomState(7)
>>> rng.normal()
Array(-1.4622004, dtype=float32)
>>> rng.normal()
Array(2.0224454, dtype=float32)

ball(d: int, p: float = 2, shape: Sequence[int] = (), dtype: DTypeLike = float)

Sample uniformly from the unit Lp ball.

Reference: https://arxiv.org/abs/math/0503650.

Parameters:
  • d – a nonnegative int representing the dimensionality of the ball.

  • p – a float representing the p parameter of the Lp norm.

  • shape – optional, the batch dimensions of the result. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array of shape (*shape, d) and specified dtype.
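The cited reference (Barthe, Guédon, Mendelson and Naor) gives a simple construction: draw d i.i.d. coordinates with density proportional to exp(-|x|^p), draw one Exponential(1) variable W, and divide the coordinates by (sum |x_i|^p + W)^(1/p). A sketch of that construction with jax.random; `ball_sketch` is a hypothetical name, not part of ifnt.

```python
import jax
import jax.numpy as jnp


def ball_sketch(key, d: int, p: float = 2.0, shape=()):
    """Uniform samples from the unit Lp ball with shape (*shape, d) (illustrative sketch)."""
    k1, k2 = jax.random.split(key)
    # Coordinates with density proportional to exp(-|x|^p).
    g = jax.random.generalized_normal(k1, p, (*shape, d))
    # One Exponential(1) variable per sample.
    w = jax.random.exponential(k2, shape)
    # Dividing by (||g||_p^p + w)^(1/p) lands uniformly inside the ball.
    scale = (jnp.sum(jnp.abs(g) ** p, axis=-1) + w) ** (1.0 / p)
    return g / scale[..., None]


x = ball_sketch(jax.random.key(0), d=3, p=2.0, shape=(1000,))
```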

bernoulli(p: RealArray = np.float32(0.5), shape: Shape | NamedShape | None = None) → Array

Sample Bernoulli random values with given shape and mean.

The values are distributed according to the probability mass function:

\[f(k; p) = p^k(1 - p)^{1 - k}\]

where \(k \in \{0, 1\}\) and \(0 \le p \le 1\).

Parameters:
  • p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.

Returns:

A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
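A Bernoulli(p) draw can be realized by comparing a uniform variate with p, since U < p holds with probability p. The sketch below does this directly with `jax.random.uniform`; it illustrates the distribution and makes no claim about how `bernoulli` is implemented internally.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
p = 0.3
# U < p holds with probability p, giving a boolean Bernoulli(p) sample.
samples = jax.random.uniform(key, (10_000,)) < p
```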

beta(a: RealArray, b: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Beta random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1}\]

on the domain \(0 \le x \le 1\).

Parameters:
  • a – a float or array of floats broadcast-compatible with shape representing the first parameter “alpha”.

  • b – a float or array of floats broadcast-compatible with shape representing the second parameter “beta”.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.
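A Beta(a, b) variable can be built from two independent standard gamma draws as X / (X + Y), with X ~ Gamma(a) and Y ~ Gamma(b). The sketch checks that construction against the Beta mean a / (a + b); it illustrates the distribution, not necessarily how the sampler works internally, and the values of a and b are arbitrary.

```python
import jax

a, b = 2.0, 5.0
kx, ky = jax.random.split(jax.random.key(0))
x = jax.random.gamma(kx, a, (20_000,))
y = jax.random.gamma(ky, b, (20_000,))
# Ratio of gamma draws gives Beta(a, b) samples in [0, 1].
beta_samples = x / (x + y)
```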

binomial(n: RealArray, p: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Binomial random values with given shape and float dtype.

The values are returned according to the probability mass function:

\[f(k;n,p) = \binom{n}{k}p^k(1-p)^{n-k}\]

where \(k \in \{0, 1, \ldots, n\}\), \(n\) is a nonnegative integer representing the number of trials, and \(0 < p < 1\) is a float representing the probability of success of an individual trial.

Parameters:
  • n – a float or array of floats broadcast-compatible with shape representing the number of trials.

  • p – a float or array of floats broadcast-compatible with shape representing the probability of success of an individual trial.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with n and p. The default (None) produces a result shape equal to np.broadcast(n, p).shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by np.broadcast(n, p).shape.

bits(shape: Shape = (), dtype: DTypeLikeUInt | None = None) → Array

Sample uniform bits in the form of unsigned integers.

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, an unsigned integer dtype for the returned values (default uint64 if jax_enable_x64 is true, otherwise uint32).

Returns:

A random array with the specified shape and dtype.

categorical(logits: RealArray, axis: int = -1, shape: Shape | None = None) → Array

Sample random values from categorical distributions.

Parameters:
  • logits – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.

  • axis – Axis along which logits belong to the same categorical distribution.

  • shape – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).

Returns:

A random array with int dtype and shape given by shape if shape is not None, or else np.delete(logits.shape, axis).
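Sampling from softmax(logits) can be done with the Gumbel-max trick: add independent standard Gumbel noise to the logits and take the argmax along the category axis. The sketch uses `jax.random.gumbel` directly and compares empirical frequencies with the target probabilities, which are arbitrary examples.

```python
import jax
import jax.numpy as jnp

logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))
key = jax.random.key(0)
n = 20_000
# Gumbel-max trick: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits)).
g = jax.random.gumbel(key, (n, logits.shape[0]))
draws = jnp.argmax(logits + g, axis=-1)
freqs = jnp.bincount(draws, length=3) / n
```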

cauchy(shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample Cauchy random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) \propto \frac{1}{x^2 + 1}\]

on the domain \(-\infty < x < \infty\).

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.

chisquare(df: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Chisquare random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x; \nu) \propto x^{\nu/2 - 1}e^{-x/2}\]

on the domain \(0 < x < \infty\), where \(\nu > 0\) represents the degrees of freedom, given by the parameter df.

Parameters:
  • df – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. The default (None) produces a result shape equal to df.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.
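Comparing the density above with the standard gamma density shows that a chi-square variable with ν degrees of freedom has the same law as 2 * Gamma(ν / 2). A sketch of that identity with an arbitrary df, checking the mean ν:

```python
import jax

df = 4.0
key = jax.random.key(0)
# A chi-square variable with df degrees of freedom equals twice a Gamma(df / 2) variable.
chi2 = 2.0 * jax.random.gamma(key, df / 2, (20_000,))
```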

choice(a: int | ArrayLike, shape: Shape = (), replace: bool = True, p: RealArray | None = None, axis: int = 0) → Array

Generates a random sample from a given array.

Warning

If p has fewer non-zero elements than the requested number of samples, as specified in shape, and replace=False, the output of this function is ill-defined. Please make sure to use appropriate inputs.

Parameters:
  • a – array or int. If an ndarray, a random sample is generated from its elements. If an int, the random sample is generated as if a were arange(a).

  • shape – tuple of ints, optional. Output shape. If the given shape is, e.g., (m, n), then m * n samples are drawn. Default is (), in which case a single value is returned.

  • replace – boolean. Whether the sample is with or without replacement. Default is True.

  • p – 1-D array-like. The probabilities associated with each entry in a. If not given, the sample assumes a uniform distribution over all entries in a.

  • axis – int, optional. The axis along which the selection is performed. The default, 0, selects by row.

Returns:

An array of shape shape containing samples from a.

dirichlet(alpha: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Dirichlet random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(\{x_i\}; \{\alpha_i\}) \propto \prod_{i=1}^k x_i^{\alpha_i - 1}\]

where \(k\) is the dimension, and \(\{x_i\}\) satisfies

\[\sum_{i=1}^k x_i = 1\]

and \(0 \le x_i \le 1\) for all \(x_i\).

Parameters:
  • alpha – an array of shape (..., n) used as the concentration parameter of the random variables.

  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis, which has size n. Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.
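Two illustrations of the shape convention above: a Dirichlet sample obtained by normalizing independent gamma draws (a standard construction, shown for intuition rather than as the library's internals), and a direct call to `jax.random.dirichlet` with batch shape (5,), whose result has shape (5,) + (alpha.shape[-1],). The alpha values are arbitrary.

```python
import jax
import jax.numpy as jnp

alpha = jnp.array([1.0, 2.0, 3.0])
k1, k2 = jax.random.split(jax.random.key(0))

# Normalizing independent Gamma(alpha_i) draws along the last axis gives Dirichlet(alpha).
g = jax.random.gamma(k1, alpha, (5, 3))
x = g / jnp.sum(g, axis=-1, keepdims=True)

# Batch shape (5,) plus the trailing axis of size alpha.shape[-1].
y = jax.random.dirichlet(k2, alpha, (5,))
```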

double_sided_maxwell(loc: ArrayLike, scale: ArrayLike, shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample from a double sided Maxwell distribution.

The values are distributed according to the probability density function:

\[f(x;\mu,\sigma) \propto z^2 e^{-z^2 / 2}\]

where \(z = (x - \mu) / \sigma\), with the center \(\mu\) specified by loc and the scale \(\sigma\) specified by scale.

Parameters:
  • loc – The location parameter of the distribution.

  • scale – The scale parameter of the distribution.

  • shape – The shape added to the broadcast shape of loc and scale.

  • dtype – The type used for samples.

Returns:

A jnp.array of samples.

exponential(shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample Exponential random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = e^{-x}\]

on the domain \(0 \le x < \infty\).

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.
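A standard way to draw from the density above is inverse-CDF sampling: the CDF is 1 - e^{-x}, so -log(1 - U) with U uniform on [0, 1) is Exponential(1). A sketch that checks the mean of 1:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
u = jax.random.uniform(key, (20_000,))
# Inverse-CDF sampling: if U ~ Uniform[0, 1), then -log(1 - U) ~ Exponential(1).
e = -jnp.log1p(-u)
```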

f(dfnum: RealArray, dfden: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample F-distribution random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x; \nu_1, \nu_2) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{ -(\nu_1 + \nu_2) / 2}\]

on the domain \(0 < x < \infty\). Here \(\nu_1\) is the degrees of freedom of the numerator (dfnum), and \(\nu_2\) is the degrees of freedom of the denominator (dfden).

Parameters:
  • dfnum – a float or array of floats broadcast-compatible with shape representing the numerator’s df of the distribution.

  • dfden – a float or array of floats broadcast-compatible with shape representing the denominator’s df of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with dfnum and dfden. The default (None) produces a result shape by broadcasting dfnum and dfden.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by broadcasting dfnum and dfden.

fold_in(data: ArrayLike) → Array

Folds in data to a PRNG key to form a new PRNG key.

Parameters:
  • data – a 32-bit integer representing data to be folded into the key.

Returns:

A new PRNG key that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values.
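The determinism described above can be checked with the stateless `jax.random.fold_in` and `jax.random.key_data`. This sketch says nothing about which key the ifnt method folds the data into; that detail is not documented here.

```python
import jax

base = jax.random.key(0)
# Folding the same data into the same key is deterministic...
k_a = jax.random.fold_in(base, 42)
k_b = jax.random.fold_in(base, 42)
# ...while different data yields a different key.
k_c = jax.random.fold_in(base, 43)

same = bool((jax.random.key_data(k_a) == jax.random.key_data(k_b)).all())
different = bool((jax.random.key_data(k_a) != jax.random.key_data(k_c)).any())
```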

gamma(a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Gamma random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x;a) \propto x^{a - 1} e^{-x}\]

on the domain \(0 \le x < \infty\), with \(a > 0\).

This is the standard gamma density, with a unit scale/rate parameter. Dividing the sample output by the rate is equivalent to sampling from gamma(a, rate), and multiplying the sample output by the scale is equivalent to sampling from gamma(a, scale).

Parameters:
  • a – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.

See also

loggamma : sample gamma values in log-space, which can provide improved accuracy for small values of a.
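Following the rescaling rule for gamma above, a sketch that turns unit-rate samples into Gamma(a, rate) samples and checks the mean a / rate; the values of a and rate are arbitrary.

```python
import jax

a, rate = 3.0, 2.0
key = jax.random.key(0)
# Standard (unit-rate) gamma samples divided by the rate: Gamma(a, rate=2), i.e. scale 1/2.
samples = jax.random.gamma(key, a, (20_000,)) / rate
```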

generalized_normal(p: float, shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample from the generalized normal distribution.

The values are returned according to the probability density function:

\[f(x;p) \propto e^{-|x|^p}\]

on the domain \(-\infty < x < \infty\), where \(p > 0\) is the shape parameter.

Parameters:
  • p – a float representing the shape parameter.

  • shape – optional, the batch dimensions of the result. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.

geometric(p: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) → Array

Sample Geometric random values with given shape and integer dtype.

The values are returned according to the probability mass function:

\[f(k;p) = p(1-p)^{k-1}\]

where \(k \ge 1\) is the number of trials up to and including the first success, and \(0 < p < 1\).

Parameters:
  • p – a float or array of floats broadcast-compatible with shape representing the probability of success of an individual trial.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with p. The default (None) produces a result shape equal to np.shape(p).

  • dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by p.shape.
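The probability mass function above can be sampled by inverting its CDF: with U uniform on [0, 1), floor(log(1 - U) / log(1 - p)) + 1 takes values in {1, 2, ...} and satisfies P(K > k) = (1 - p)^k. A sketch with an arbitrary p, checking the mean 1 / p:

```python
import jax
import jax.numpy as jnp

p = 0.25
key = jax.random.key(0)
u = jax.random.uniform(key, (20_000,))
# Inverse-CDF sampling for the number of trials up to the first success (k >= 1).
k = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p)) + 1
```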

get_key() → Array

Get a random key and update the state.

gumbel(shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample Gumbel random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = e^{-(x + e^{-x})}\]

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.
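Standard Gumbel samples can be produced by the inverse CDF, -log(-log U). The sketch keeps U strictly positive via `minval` so the double logarithm stays finite, and compares the sample mean with the Euler-Mascheroni constant (about 0.5772), which is the mean of this distribution.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
# Draw U from [tiny, 1) so that log(U) is finite.
u = jax.random.uniform(key, (20_000,), minval=jnp.finfo(jnp.float32).tiny, maxval=1.0)
# Inverse CDF of the standard Gumbel distribution.
g = -jnp.log(-jnp.log(u))
```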

laplace(shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample Laplace random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = \frac{1}{2}e^{-|x|}\]

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.
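The difference of two independent Exponential(1) variables has exactly the density (1/2)e^{-|x|} above. A sketch that builds Laplace samples that way and checks the mean 0 and variance 2:

```python
import jax

k1, k2 = jax.random.split(jax.random.key(0))
n = 20_000
# E1 - E2 with independent E1, E2 ~ Exponential(1) is standard Laplace.
x = jax.random.exponential(k1, (n,)) - jax.random.exponential(k2, (n,))
```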

loggamma(a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample log-gamma random values with given shape and float dtype.

This function is implemented such that the following will hold for a dtype-appropriate tolerance:

np.testing.assert_allclose(jnp.exp(loggamma(*args)), gamma(*args), rtol=rtol)

The benefit of log-gamma is that for samples very close to zero (which occur frequently when a << 1) sampling in log space provides better precision.

Parameters:
  • a – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.

See also

gamma : standard gamma sampler.

logistic(shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample logistic random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}\]

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.

lognormal(sigma: RealArray = np.float32(1.0), shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample lognormal random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = \frac{1}{x\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(\log x)^2}{2\sigma^2}\right)\]

on the domain \(x > 0\).

Parameters:
  • sigma – a float or array of floats broadcast-compatible with shape representing the standard deviation of the underlying normal distribution. Default 1.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. The default (None) produces a result shape equal to ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape.
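By definition, exp(σZ) with Z standard normal has the density above, so its median is 1 for any σ. A sketch with an arbitrary σ = 0.5:

```python
import jax
import jax.numpy as jnp

sigma = 0.5
key = jax.random.key(0)
z = jax.random.normal(key, (20_000,))
# Exponentiating a scaled standard normal gives lognormal samples.
x = jnp.exp(sigma * z)
```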

maxwell(shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample from a one sided Maxwell distribution.

The values are distributed according to the probability density function:

\[f(x) \propto x^2 e^{-x^2 / 2}\]

on the domain \(0 \le x < \infty\).

Parameters:
  • shape – The shape of the returned samples.

  • dtype – The type used for samples.

Returns:

A jnp.array of samples, of shape shape.

multivariate_normal(mean: RealArray, cov: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat | None = None, method: str = 'cholesky') → Array

Sample multivariate normal random values with given mean and covariance.

The values are returned according to the probability density function:

\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]

where \(k\) is the dimension, \(\mu\) is the mean (given by mean) and \(\Sigma\) is the covariance matrix (given by cov).

Parameters:
  • mean – a mean vector of shape (..., n).

  • cov – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.

  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • method – optional, a method to compute the factor of cov. Must be one of ‘svd’, ‘eigh’, or ‘cholesky’. Default ‘cholesky’. For singular covariance matrices, use ‘svd’ or ‘eigh’.

Returns:

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
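With a Cholesky factor L satisfying cov = L Lᵀ, mapping standard normal vectors z to mean + L z yields the density above. A minimal sketch of that idea, with an arbitrary mean and covariance, checked against the empirical moments; it is an illustration, not the library's code path.

```python
import jax
import jax.numpy as jnp

mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])
key = jax.random.key(0)
# Factor cov = L L^T, then map each standard normal row z to mean + L z.
L = jnp.linalg.cholesky(cov)
z = jax.random.normal(key, (50_000, 2))
x = mean + z @ L.T
emp_cov = jnp.cov(x, rowvar=False)
```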

normal(shape: Shape | NamedShape = (), dtype: DTypeLikeFloat = float) → Array

Sample standard normal random values with given shape and float dtype.

The values are returned according to the probability density function:

\[f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}\]

on the domain \(-\infty < x < \infty\).

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.

orthogonal(n: int, shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample uniformly from the orthogonal group O(n).

If the dtype is complex, sample uniformly from the unitary group U(n).

Parameters:
  • n – an integer indicating the resulting dimension.

  • shape – optional, the batch dimensions of the result. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array of shape (*shape, n, n) and specified dtype.
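A common construction for a uniformly (Haar) distributed orthogonal matrix is the QR decomposition of a Gaussian matrix, with each column of Q multiplied by the sign of the matching diagonal entry of R. The sketch below shows that construction for n = 4 and checks orthogonality; it is not necessarily the algorithm `orthogonal` uses.

```python
import jax
import jax.numpy as jnp

n = 4
key = jax.random.key(0)
z = jax.random.normal(key, (n, n))
q, r = jnp.linalg.qr(z)
# Multiply column j of Q by sign(R[j, j]) so the result is Haar-distributed.
q = q * jnp.sign(jnp.diagonal(r))
```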

pareto(b: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Pareto random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x; b) = b / x^{b + 1}\]

on the domain \(1 \le x < \infty\) with \(b > 0\).

Parameters:
  • b – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.
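If E ~ Exponential(1), then exp(E / b) has survival function x^{-b} on x ≥ 1, which matches the density above. A sketch with an arbitrary b = 3, whose mean is b / (b - 1) = 1.5:

```python
import jax
import jax.numpy as jnp

b = 3.0
key = jax.random.key(0)
# exp(E / b) with E ~ Exponential(1) is Pareto(b) on [1, inf).
x = jnp.exp(jax.random.exponential(key, (20_000,)) / b)
```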

permutation(x: int | ArrayLike, axis: int = 0, independent: bool = False) → Array

Returns a randomly permuted array or range.

Parameters:
  • x – int or array. If x is an integer, randomly shuffle np.arange(x). If x is an array, randomly shuffle its elements.

  • axis – int, optional. The axis which x is shuffled along. Default is 0.

  • independent – bool, optional. If set to True, each individual vector along the given axis is shuffled independently. Default is False.

Returns:

A shuffled version of x, or of np.arange(x) if x is an integer.

poisson(lam: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) → Array

Sample Poisson random values with given shape and integer dtype.

The values are distributed according to the probability mass function:

\[f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!}\]

where k is a nonnegative integer and \(\lambda > 0\).

Parameters:
  • lam – rate parameter (mean of the distribution), must be >= 0. Must be broadcast-compatible with shape.

  • shape – optional, a tuple of nonnegative integers representing the result shape. The default (None) produces a result shape equal to lam.shape.

  • dtype – optional, an integer dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by lam.shape.

rademacher(shape: Sequence[int] = (), dtype: DTypeLike = int) → Array

Sample from a Rademacher distribution.

The values are distributed according to the probability mass function:

\[f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))\]

on the domain \(k \in \{-1, 1\}\), where \(\delta(x)\) is the Dirac delta function.

Parameters:
  • shape – The shape of the returned samples. Default ().

  • dtype – The type used for samples.

Returns:

A jnp.array of samples, of shape shape. Each element in the output has a 50% chance of being 1 or -1.
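A Rademacher sample is a fair coin mapped to ±1, so one way to produce it is 2B - 1 with B ~ Bernoulli(0.5). A sketch using `jax.random.bernoulli`, not necessarily how `rademacher` is implemented:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
# Map a fair Bernoulli draw {False, True} to {-1, +1}.
signs = 2 * jax.random.bernoulli(key, 0.5, (10_000,)).astype(jnp.int32) - 1
```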

randint(shape: Sequence[int], minval: ArrayLike, maxval: ArrayLike, dtype: DTypeLike = int) → Array

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • shape – a tuple of nonnegative integers representing the shape.

  • minval – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.

  • maxval – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.

  • dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

Returns:

A random array with the specified shape and dtype.

rayleigh(scale: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Rayleigh random values with given shape and float dtype.

The values are returned according to the probability density function:

\[f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)}\]

on the domain \(0 \le x < \infty\), where \(\sigma > 0\) is the scale parameter of the distribution.

Parameters:
  • scale – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with scale. The default (None) produces a result shape equal to scale.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by scale.shape.
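The Rayleigh survival function is exp(-x²/(2σ²)), so inverting it gives X = σ·sqrt(-2 log(1 - U)) for U uniform on [0, 1). A sketch with an arbitrary σ = 2, checking the mean σ·sqrt(π/2):

```python
import jax
import jax.numpy as jnp

scale = 2.0
key = jax.random.key(0)
u = jax.random.uniform(key, (20_000,))
# Inverse CDF: P(X > x) = exp(-x^2 / (2 scale^2)), so X = scale * sqrt(-2 log(1 - U)).
x = scale * jnp.sqrt(-2.0 * jnp.log1p(-u))
```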

t(df: ArrayLike, shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample Student’s t random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(t; \nu) \propto \left(1 + \frac{t^2}{\nu}\right)^{-(\nu + 1)/2}\]

where \(\nu > 0\) is the degrees of freedom, given by the parameter df.

Parameters:
  • df – a float or array of floats broadcast-compatible with shape representing the degrees of freedom parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. The default (None) produces a result shape equal to df.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.
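The classical definition of Student's t is Z / sqrt(V / ν), with Z standard normal and V an independent chi-square variable with ν degrees of freedom. A sketch of that definition with an arbitrary ν = 5, checking that the sample mean is near 0:

```python
import jax
import jax.numpy as jnp

df = 5.0
kz, kv = jax.random.split(jax.random.key(0))
n = 20_000
z = jax.random.normal(kz, (n,))
v = jax.random.chisquare(kv, df, (n,))
# Z / sqrt(V / df) with independent Z ~ N(0, 1) and V ~ chi2(df) is Student's t.
t_samples = z / jnp.sqrt(v / df)
```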

triangular(left: RealArray, mode: RealArray, right: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Triangular random values with given shape and float dtype.

The values are returned according to the probability density function:

\[\begin{split}f(x; a, b, c) = \frac{2}{c-a} \left\{ \begin{array}{ll} \frac{x-a}{b-a} & a \leq x \leq b \\ \frac{c-x}{c-b} & b \leq x \leq c \end{array} \right.\end{split}\]

on the domain \(a \leq x \leq c\).

Parameters:
  • left – a float or array of floats broadcast-compatible with shape representing the lower limit parameter of the distribution.

  • mode – a float or array of floats broadcast-compatible with shape representing the peak value parameter of the distribution, value must fulfill the condition left <= mode <= right.

  • right – a float or array of floats broadcast-compatible with shape representing the upper limit parameter of the distribution, must be larger than left.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with left, mode, and right. The default (None) produces a result shape by broadcasting left, mode, and right.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by broadcasting left, mode, and right.

truncated_normal(lower: RealArray, upper: RealArray, shape: Shape | NamedShape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample truncated standard normal random values with given shape and dtype.

The values are returned according to the probability density function:

\[f(x) \propto e^{-x^2/2}\]

on the domain \(\rm{lower} < x < \rm{upper}\).

Parameters:
  • lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.

  • upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).

uniform(shape: Shape | NamedShape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0.0, maxval: RealArray = 1.0) → Array

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • minval – optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).

  • maxval – optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).

Returns:

A random array with the specified shape and dtype.

wald(mean: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) → Array

Sample Wald random values with given shape and float dtype.

The values are returned according to the probability density function:

\[f(x;\mu) = \frac{1}{\sqrt{2\pi x^3}} \exp\left(-\frac{(x - \mu)^2}{2\mu^2 x}\right)\]

on the domain \(0 < x < \infty\), where \(\mu > 0\) is the mean parameter of the distribution.

Parameters:
  • mean – a float or array of floats broadcast-compatible with shape representing the mean parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with mean. The default (None) produces a result shape equal to np.shape(mean).

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by mean.shape.

weibull_min(scale: ArrayLike, concentration: ArrayLike, shape: Sequence[int] = (), dtype: DTypeLike = float) → Array

Sample from a Weibull distribution.

The values are distributed according to the probability density function:

\[f(x;\sigma,c) \propto x^{c - 1} \exp(-(x / \sigma)^c)\]

on the domain \(0 < x < \infty\), where \(c > 0\) is the concentration parameter, and \(\sigma > 0\) is the scale parameter.

Parameters:
  • scale – The scale parameter of the distribution.

  • concentration – The concentration parameter of the distribution.

  • shape – The shape of the samples, broadcast against the shapes of scale and concentration.

  • dtype – The type used for samples.

Returns:

A jnp.array of samples.
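Integrating the density above gives the survival function exp(-(x/σ)^c), so X = σ·(-log(1 - U))^(1/c) with U uniform on [0, 1) follows this distribution. A sketch with arbitrary σ = 1 and c = 2, whose mean is Γ(1 + 1/c) ≈ 0.8862:

```python
import jax
import jax.numpy as jnp

scale, concentration = 1.0, 2.0
key = jax.random.key(0)
u = jax.random.uniform(key, (20_000,))
# Inverse CDF: P(X > x) = exp(-(x / scale) ** concentration).
x = scale * (-jnp.log1p(-u)) ** (1.0 / concentration)
```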

ifnt.random.keys(seed)

Random key generator.

Parameters:

seed – Initial seed, given either as an integer or as a JAX PRNG key.

Example

>>> keys = ifnt.random.keys(9)
>>> next(keys)
Array((), dtype=key<fry>) overlaying: [4109519897 3077142452]
>>> next(keys)
Array((), dtype=key<fry>) overlaying: [3656642974 2192743943]
>>> keys = ifnt.random.keys(jax.random.key(9))
>>> next(keys)
Array((), dtype=key<fry>) overlaying: [4109519897 3077142452]
>>> next(keys)
Array((), dtype=key<fry>) overlaying: [3656642974 2192743943]
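A generator like keys can be sketched in a few lines: keep the current key and split it on every next() call. `keys_sketch` is hypothetical; it accepts an integer or an existing key, mirroring the two calls in the example, but it is not guaranteed to reproduce the exact key values printed there.

```python
import jax


def keys_sketch(seed):
    """Hypothetical infinite generator of fresh PRNG keys (not ifnt's code)."""
    # Accept an integer seed or an already-constructed key.
    key = jax.random.key(seed) if isinstance(seed, int) else seed
    while True:
        # Keep one half as the running state and yield the other.
        key, subkey = jax.random.split(key)
        yield subkey


gen = keys_sketch(9)
first, second = next(gen), next(gen)
```

Starting from `keys_sketch(9)` or `keys_sketch(jax.random.key(9))` yields the same sequence, since both begin from the same key.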