sbss.common.distributions.ApproxBernoulli

class sbss.common.distributions.ApproxBernoulli(temperature, probs=None, logits=None, validate_args=None)

Approximation of a Bernoulli distribution using a relaxed Bernoulli formulation.

This class provides a straight-through estimator variant of the Relaxed Bernoulli distribution: reparameterized samples are thresholded at 0.5, so the forward pass yields hard 0/1 values, while gradients still flow through the underlying relaxed sample in the backward pass.
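The straight-through trick described above can be sketched with `torch.distributions.RelaxedBernoulli`. This is a minimal sketch, assuming ApproxBernoulli works roughly this way; the helper `straight_through_bernoulli` is hypothetical and the class's actual implementation may differ. Writing the output as `hard - soft.detach() + soft` makes the forward value exactly `hard`, while the derivative with respect to `soft` is 1.

```python
import torch
from torch.distributions import RelaxedBernoulli


def straight_through_bernoulli(temperature, logits, sample_shape=torch.Size()):
    """Hypothetical helper: hard 0/1 samples that backpropagate like relaxed ones."""
    relaxed = RelaxedBernoulli(temperature, logits=logits)
    # Differentiable sample in (0, 1) via the reparameterization trick.
    soft = relaxed.rsample(sample_shape)
    # Forward pass: threshold at 0.5 to obtain hard 0/1 values (no gradient).
    hard = (soft > 0.5).to(soft.dtype)
    # (hard - soft.detach()) is a constant for autograd, so the result equals
    # `hard` numerically while d(result)/d(soft) == 1.
    return hard - soft.detach() + soft


torch.manual_seed(0)
logits = torch.zeros(4, requires_grad=True)
# torch documents RelaxedBernoulli's temperature as a Tensor.
z = straight_through_bernoulli(torch.tensor(0.5), logits, (8,))
z.sum().backward()
print(z)            # only 0.0 and 1.0 entries
print(logits.grad)  # non-zero: gradients flowed through the relaxed sample
```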

Parameters:
  • temperature (float) – Relaxation temperature; lower values push the relaxed samples closer to 0 and 1.

  • probs (torch.Tensor, optional) – Probability of success.

  • logits (torch.Tensor, optional) – Log-odds of success. Specify either probs or logits, but not both.

  • validate_args (bool, optional) – Whether to validate the input arguments.
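How temperature, probs and logits interact can be illustrated with the standard binary Concrete (RelaxedBernoulli) reparameterization: logistic noise is added to the log-odds, the sum is divided by the temperature, and a sigmoid maps it into (0, 1). The pure-Python sketch below assumes ApproxBernoulli builds on this relaxation; the helper names are hypothetical. It also checks a property of the 0.5 threshold: since sigmoid(x) > 0.5 exactly when x > 0, a thresholded sample equals 1 with probability probs at every temperature, so (assuming the thresholding described above) the temperature changes the relaxed values and their gradients, not the distribution of the hard samples.

```python
import math
import random


def logit(p):
    """Convert a probability in (0, 1) to log-odds (what `logits` holds)."""
    return math.log(p) - math.log1p(-p)


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def relaxed_bernoulli_sample(logits, temperature, rng):
    """Hypothetical helper: one binary Concrete (RelaxedBernoulli) sample in [0, 1]."""
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    noise = math.log(u) - math.log1p(-u)  # standard logistic noise
    return sigmoid((logits + noise) / temperature)


rng = random.Random(0)
p, n = 0.3, 20000
results = {}
for temperature in (0.1, 1.0, 5.0):
    samples = [relaxed_bernoulli_sample(logit(p), temperature, rng) for _ in range(n)]
    # Threshold at 0.5: s > 0.5 exactly when logits + noise > 0, which has probability p.
    hard_rate = sum(s > 0.5 for s in samples) / n
    # Fraction of relaxed samples within 0.05 of either 0 or 1.
    near_binary = sum(min(s, 1.0 - s) < 0.05 for s in samples) / n
    results[temperature] = (hard_rate, near_binary)
    print(f"T={temperature}: P(hard=1)~{hard_rate:.3f}, near 0/1: {near_binary:.3f}")
```

The hard rate stays near 0.3 for all three temperatures, while the share of near-binary relaxed samples shrinks as the temperature grows.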

__init__(temperature, probs=None, logits=None, validate_args=None)

Methods

__init__(temperature[, probs, logits, ...])

cdf(value)

Computes the cumulative distribution function by inverting the transform(s) and computing the score of the base distribution.

entropy()

Returns the entropy of the distribution, batched over batch_shape.

enumerate_support([expand])

Returns a tensor containing all values supported by a discrete distribution.

expand(batch_shape[, _instance])

Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape.

icdf(value)

Computes the inverse cumulative distribution function by evaluating the inverse CDF of the base distribution and then applying the transform(s).

log_prob(value)

Scores the sample by inverting the transform(s) and combining the log-probability under the base distribution with the log absolute determinant of the Jacobian.

perplexity()

Returns the perplexity of the distribution, batched over batch_shape.

rsample([sample_shape])

Generates a sample_shape-shaped reparameterized sample, or a sample_shape-shaped batch of reparameterized samples if the distribution parameters are batched.

sample([sample_shape])

Generates a sample_shape-shaped sample, or a sample_shape-shaped batch of samples if the distribution parameters are batched.

sample_n(n)

Generates n samples or n batches of samples if the distribution parameters are batched.

set_default_validate_args(value)

Sets whether validation is enabled or disabled.
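The log_prob entry above describes a change of variables. Assuming ApproxBernoulli inherits log_prob unchanged from `torch.distributions.RelaxedBernoulli` (a `LogitRelaxedBernoulli` base distribution pushed through a sigmoid transform), the computation can be mirrored in plain Python. The function names are hypothetical and this is an illustrative re-derivation, not the library source. Since the Jacobian of the sigmoid is y(1 - y), its log is subtracted from the base log-density evaluated at logit(y).

```python
import math


def logit_relaxed_log_prob(x, logits, temperature):
    """Log-density of the base variable X = (logits + L) / temperature, L ~ Logistic(0, 1)."""
    diff = logits - temperature * x
    # log(1 + exp(diff)) computed stably (softplus)
    softplus = max(diff, 0.0) + math.log1p(math.exp(-abs(diff)))
    return math.log(temperature) + diff - 2.0 * softplus


def relaxed_log_prob(y, logits, temperature):
    """Score y in (0, 1): invert the sigmoid, score the base, subtract log|det J|."""
    x = math.log(y) - math.log1p(-y)  # inverse of the sigmoid transform
    log_abs_det_jacobian = math.log(y) + math.log1p(-y)  # log(sigmoid'(x)) = log(y * (1 - y))
    return logit_relaxed_log_prob(x, logits, temperature) - log_abs_det_jacobian


def total_mass(logits, temperature, n=20000):
    """Midpoint-rule integral of the density over (0, 1); should be close to 1."""
    h = 1.0 / n
    return h * sum(math.exp(relaxed_log_prob((i + 0.5) * h, logits, temperature)) for i in range(n))


logits = math.log(0.3 / 0.7)  # log-odds for probs = 0.3
for temperature in (1.0, 2.0):
    print(temperature, total_mass(logits, temperature))
```

This density is defined only on the open interval (0, 1), because logit(y) diverges at the endpoints.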

Attributes

arg_constraints

batch_shape

Returns the shape over which parameters are batched.

event_shape

Returns the shape of a single sample (without batching).

has_enumerate_support

has_rsample

logits

mean

Returns the mean of the distribution.

mode

Returns the mode of the distribution.

probs

stddev

Returns the standard deviation of the distribution.

support

temperature

variance

Returns the variance of the distribution.

base_dist
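The shape attributes follow the usual torch.distributions conventions: a sample has shape sample_shape + batch_shape + event_shape. The example below uses `torch.distributions.RelaxedBernoulli` as a stand-in and assumes ApproxBernoulli keeps the same semantics for batch_shape, event_shape, rsample and expand.

```python
import torch
from torch.distributions import RelaxedBernoulli

# Three independent relaxed Bernoulli variables: batch_shape == (3,), scalar events.
dist = RelaxedBernoulli(torch.tensor(0.5), probs=torch.tensor([0.1, 0.5, 0.9]))
print(dist.batch_shape, dist.event_shape)  # torch.Size([3]) torch.Size([])

# Samples have shape sample_shape + batch_shape + event_shape.
samples = dist.rsample((5,))
print(samples.shape)  # torch.Size([5, 3])

# expand() broadcasts the parameters to a larger batch_shape.
wide = dist.expand(torch.Size([2, 3]))
print(wide.batch_shape)  # torch.Size([2, 3])
print(wide.rsample((4,)).shape)  # torch.Size([4, 2, 3])
```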