
Drop-Path

eqxvision.layers.DropPath

Effectively drops a sample from the call. Often used inside a network alongside a residual connection, where it randomly skips the residual branch. Equivalent to torchvision.ops.stochastic_depth.
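As a rough illustration of the technique (not eqxvision's own code), the sketch below applies a global drop path to the branch of a toy residual block, so the skip path always survives while the branch is either kept or dropped as a whole. It assumes, as torchvision.ops.stochastic_depth does, that kept values are rescaled by 1 / (1 - p) so the expected output is unchanged. `drop_path_global` and `residual_block` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def drop_path_global(x, *, key, p):
    """Hypothetical helper: drop the whole input with probability p."""
    if p == 0.0:
        return x
    keep_prob = 1.0 - p
    # A single Bernoulli draw decides the fate of the entire input.
    keep = jax.random.bernoulli(key, p=keep_prob)
    # Assumed rescaling (as in torchvision.ops.stochastic_depth) keeps E[output] == x.
    return jnp.where(keep, x / keep_prob, jnp.zeros_like(x))


def residual_block(x, *, key, p):
    """Toy residual block: only the branch can be dropped, never the skip path."""
    branch = jnp.tanh(x)  # stand-in for a real sub-network
    return x + drop_path_global(branch, key=key, p=p)


x = jnp.arange(4.0)
y = residual_block(x, key=jax.random.PRNGKey(0), p=0.5)
```

With p = 0.5 the result is either x (branch dropped) or x + 2 * tanh(x) (branch kept and rescaled).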

__init__(self, p: float = 0.0, inference: bool = False, mode = 'global')

Arguments:

  • p: The probability of dropping a sample entirely during the forward pass.
  • inference: Defaults to False. If True, the input is returned unchanged. This may be toggled with equinox.tree_inference.
  • mode: Either global or local. Defaults to global. If global, the whole input is either dropped or retained as a unit. If local, the drop/keep decision is made independently for each input unit.
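The following minimal stand-in, assuming only the semantics described above, shows how p, inference and mode could interact. It is a plain frozen dataclass rather than an equinox.Module, `DropPathSketch` is a hypothetical name, and the 1 / (1 - p) rescaling of kept values is an assumption borrowed from torchvision.ops.stochastic_depth.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class DropPathSketch:
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    p: float = 0.0
    inference: bool = False
    mode: str = "global"

    def __call__(self, x, *, key):
        # Inference mode (or p == 0) is the identity.
        if self.inference or self.p == 0.0:
            return x
        keep_prob = 1.0 - self.p
        if self.mode == "global":
            # One decision for the whole input.
            keep = jax.random.bernoulli(key, p=keep_prob)
        elif self.mode == "local":
            # An independent decision for every input unit.
            keep = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
        else:
            raise ValueError(f"mode must be 'global' or 'local', got {self.mode!r}")
        # Assumed rescaling so kept values preserve the expected activation.
        return jnp.where(keep, x / keep_prob, jnp.zeros_like(x))
```

Calling dataclasses.replace(layer, inference=True) returns a copy that acts as the identity, loosely mirroring what equinox.tree_inference does for the inference field of real modules.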

Note

For mode = local, an input of shape (channels, dim_0, dim_1, ...) is reshaped to (channels, dims), where dims = dim_0 * dim_1 * ..., and then transposed to (dims, channels). The decision to drop or keep is made independently for each element of this dims x channels array.
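The note can be followed step by step in a short sketch. This is an illustration of the described reshaping, not eqxvision's implementation; `local_keep_mask` is a hypothetical name, and the keep probability 1 - p per element is an assumption. Because every element gets its own draw, the resulting mask has the same distribution as one sampled directly in the input's shape.

```python
import jax
import jax.numpy as jnp


def local_keep_mask(x, *, key, p):
    """Hypothetical helper: per-element keep mask built via the reshape in the note."""
    channels = x.shape[0]
    # (channels, dim_0, dim_1, ...) -> (channels, dims) -> (dims, channels)
    flat = x.reshape(channels, -1).transpose()
    # One independent Bernoulli(1 - p) draw per element of the dims x channels array.
    mask = jax.random.bernoulli(key, p=1.0 - p, shape=flat.shape)
    # Undo the transpose and reshape so the mask lines up with x.
    return mask.transpose().reshape(x.shape)


x = jnp.ones((3, 4, 5))  # (channels, dim_0, dim_1)
mask = local_keep_mask(x, key=jax.random.PRNGKey(0), p=0.25)
```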

__call__(self, x, *, key: jax.random.PRNGKey) -> Array

Arguments:

  • x: A JAX array of any dimensionality to drop.
  • key: A jax.random.PRNGKey used to provide randomness for deciding which elements to drop. (Keyword-only argument.)
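Because the drop decision is a pure function of key, reusing a key reproduces the same result, so each layer and each training step should receive its own key. A minimal sketch of that pattern, using a hypothetical global-mode `drop_path` function as a stand-in for the layer's __call__ (with key keyword-only, as in the signature above) and assuming 1 / (1 - p) rescaling of kept values:

```python
import jax
import jax.numpy as jnp


def drop_path(x, *, key, p=0.5):
    """Hypothetical global-mode stand-in; `key` is keyword-only, as in __call__."""
    keep = jax.random.bernoulli(key, p=1.0 - p)
    # Assumed rescaling of kept values by 1 / (1 - p).
    return jnp.where(keep, x / (1.0 - p), jnp.zeros_like(x))


x = jnp.ones((2, 3))
key = jax.random.PRNGKey(42)

# Same key, same decision: results are reproducible.
a = drop_path(x, key=key)
b = drop_path(x, key=key)

# Split the key to give each of several layers (or steps) fresh randomness.
layer_keys = jax.random.split(key, 4)
outs = [drop_path(x, key=k) for k in layer_keys]
```

Passing the key positionally, as in drop_path(x, key), raises a TypeError, which guards against accidentally swapping arguments.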