Drop-Path
eqxvision.layers.DropPath
Randomly drops an entire sample during the forward pass (or, with mode = local, individual input units).
Often used inside a network alongside a residual connection.
Equivalent to torchvision.ops.stochastic_depth.
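The following minimal JAX sketch shows a global drop path wrapped around a residual branch. The helpers `drop_path_global` and `residual_block`, the tanh branch, and the rescaling of kept samples by 1/(1 - p) (torchvision's stochastic_depth convention) are assumptions for illustration. This is not eqxvision's code.

```python
import jax
import jax.numpy as jnp


def drop_path_global(x: jax.Array, p: float, key: jax.Array) -> jax.Array:
    """Hypothetical helper: drop or keep the whole input with one coin flip."""
    keep_prob = 1.0 - p
    # A single Bernoulli draw decides the fate of the entire sample.
    keep = jax.random.bernoulli(key, keep_prob)
    # Rescale survivors by 1/keep_prob so the expected value is unchanged
    # (torchvision's convention; assumed here, not confirmed for eqxvision).
    return jnp.where(keep, x / keep_prob, jnp.zeros_like(x))


def residual_block(x: jax.Array, w: jax.Array, p: float, key: jax.Array) -> jax.Array:
    # Hypothetical residual branch: a single linear map followed by tanh.
    branch = jnp.tanh(w @ x)
    # Only the branch is dropped; the identity shortcut always passes x through.
    return x + drop_path_global(branch, p, key)


x = jnp.arange(4.0)
w = jnp.eye(4)
y = residual_block(x, w, p=0.2, key=jax.random.PRNGKey(0))
```

When the branch is dropped, the block reduces to the identity and returns `x`. When it is kept, the branch output is scaled up so that its expected contribution matches the p = 0 case.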
__init__(self, p: float = 0.0, inference: bool = False, mode = 'global')
Arguments:
p: The probability of dropping a sample entirely during the forward pass.
inference: Defaults to False. If True, the input is returned unchanged. This may be toggled with equinox.tree_inference.
mode: Can be set to global or local. If global, the whole input is dropped or retained. If local, the decision for each input unit is computed independently. Defaults to global.
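A small stand-in class can mirror the constructor arguments to show how `inference` and `mode` change behavior. `DropPathSketch` is hypothetical and rests on several assumptions. It is a frozen dataclass rather than an equinox.Module, and `dataclasses.replace` stands in for equinox.tree_inference. It also assumes that "local" means one decision per array element and that kept values are rescaled by 1/(1 - p).

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class DropPathSketch:
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    p: float = 0.0
    inference: bool = False
    mode: str = "global"

    def __call__(self, x: jax.Array, *, key: jax.Array) -> jax.Array:
        # In inference mode (or with p == 0) the input passes through unchanged.
        if self.inference or self.p == 0.0:
            return x
        keep_prob = 1.0 - self.p
        if self.mode == "global":
            # One decision for the whole input.
            keep = jax.random.bernoulli(key, keep_prob)
        elif self.mode == "local":
            # One independent decision per input unit (assumed: per element).
            keep = jax.random.bernoulli(key, keep_prob, x.shape)
        else:
            raise ValueError(f"mode must be 'global' or 'local', got {self.mode!r}")
        # Rescaling by 1/keep_prob is an assumption borrowed from torchvision.
        return jnp.where(keep, x / keep_prob, 0.0)
```

For example, `dataclasses.replace(layer, inference=True)` returns a copy that passes its input through unchanged. The documentation attributes the same effect to equinox.tree_inference.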
Note
For mode = local, an input of shape (channels, dim_0, dim_1, ...) is reshaped and transposed to (channels, dims).transpose(), where dims = dim_0 * dim_1 * .... The decision to drop or keep is then made independently for each of the dims x channels elements.
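The layout described in the Note can be sketched as follows. `local_keep_mask` is a hypothetical helper that samples a boolean keep mask on the (dims, channels) layout and maps it back to the input shape. Every element still receives its own independent draw. The transpose only determines which random draw lands on which element.

```python
import math

import jax
import jax.numpy as jnp


def local_keep_mask(key: jax.Array, shape: tuple, keep_prob: float) -> jax.Array:
    """Hypothetical helper: per-element keep mask laid out as in the Note."""
    channels = shape[0]
    dims = math.prod(shape[1:])  # dim_0 * dim_1 * ...
    # Sample one independent decision for each of the dims x channels elements,
    # i.e. on the (channels, dims).transpose() layout.
    mask_t = jax.random.bernoulli(key, keep_prob, (dims, channels))
    # Undo the transpose and reshape back to (channels, dim_0, dim_1, ...).
    return mask_t.T.reshape(shape)


mask = local_keep_mask(jax.random.PRNGKey(0), (3, 4, 5), keep_prob=0.8)
```

For instance, a (3, 4, 5) input yields 20 x 3 independent draws, which are then reshaped back to (3, 4, 5).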
__call__(self, x, *, key: jax.random.PRNGKey) -> Array
Arguments:
x: An any-dimensional JAX array to drop.
key: A jax.random.PRNGKey used to provide randomness for deciding which elements to drop. (Keyword-only argument.)
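Because `key` is required, a PRNG key must be supplied on every call. For a batch, a common JAX pattern is to split one key per sample and map the per-sample layer with `jax.vmap`, so that each sample is kept or dropped independently. `drop_sample` below is a hypothetical per-sample global drop, not eqxvision's implementation, and the 1/(1 - p) rescaling is assumed.

```python
import jax
import jax.numpy as jnp


def drop_sample(x: jax.Array, key: jax.Array, p: float = 0.5) -> jax.Array:
    # Hypothetical per-sample global drop: one coin flip for the whole sample.
    keep_prob = 1.0 - p
    keep = jax.random.bernoulli(key, keep_prob)
    # Kept samples are rescaled by 1/keep_prob (assumed convention).
    return jnp.where(keep, x / keep_prob, 0.0)


batch = jnp.ones((8, 3, 4))                          # 8 samples of shape (3, 4)
keys = jax.random.split(jax.random.PRNGKey(42), 8)   # one fresh key per sample
out = jax.vmap(drop_sample)(batch, keys)             # each sample decided separately
```

Reusing the same keys reproduces the same mask. For that reason, fresh keys (for example from `jax.random.split`) should be drawn for every training step.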